Skip to content

Commit

Permalink
fix(amazonq): Adding backoff and retry for export result archive stre…
Browse files Browse the repository at this point in the history
…aming API. (#5320)

* Adding backoff and retry for export result archive.
  • Loading branch information
ashishrp-aws authored Feb 5, 2025
1 parent 8318155 commit ac43387
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ package software.aws.toolkits.jetbrains.services.amazonq.clients

import com.intellij.testFramework.RuleChain
import com.intellij.testFramework.replaceService
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.mockito.kotlin.any
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.doAnswer
import org.mockito.kotlin.doReturn
import org.mockito.kotlin.mock
import org.mockito.kotlin.stub
Expand All @@ -20,6 +23,7 @@ import software.amazon.awssdk.services.codewhispererstreaming.CodeWhispererStrea
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveRequest
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveResponseHandler
import software.amazon.awssdk.services.codewhispererstreaming.model.ValidationException
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.core.utils.test.aString
Expand Down Expand Up @@ -81,4 +85,156 @@ class AmazonQStreamingClientTest : AmazonQTestBase() {
verify(streamingBearerClient).exportResultArchive(requestCaptor.capture(), handlerCaptor.capture())
}
}

@Test
fun `verify retry on ValidationException`(): Unit = runBlocking {
var attemptCount = 0
streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
attemptCount++
if (attemptCount <= 2) {
CompletableFuture.runAsync {
throw VALIDATION_EXCEPTION
}
} else {
CompletableFuture.completedFuture(mock())
}
}
}

amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {})

assertThat(attemptCount).isEqualTo(3)
}

@Test
fun `verify retry gives up after max attempts`(): Unit = runBlocking {
var attemptCount = 0
streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
attemptCount++
CompletableFuture.runAsync {
throw VALIDATION_EXCEPTION
}
}
}

val thrown = catchCoroutineException {
amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {})
}

assertThat(attemptCount).isEqualTo(3)
assertThat(thrown)
.isInstanceOf(ValidationException::class.java)
.hasMessage("Resource validation failed")
}

@Test
fun `verify no retry on non-retryable exception`(): Unit = runBlocking {
var attemptCount = 0

streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
attemptCount++
CompletableFuture.runAsync {
throw IllegalArgumentException("Non-retryable error")
}
}
}

val thrown = catchCoroutineException {
amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {})
}

assertThat(attemptCount).isEqualTo(1)
assertThat(thrown)
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Non-retryable error")
}

@Test
fun `verify backoff timing between retries`(): Unit = runBlocking {
var lastAttemptTime = 0L
var minBackoffObserved = Long.MAX_VALUE
var maxBackoffObserved = 0L

streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
val currentTime = System.currentTimeMillis()
if (lastAttemptTime > 0) {
val backoffTime = currentTime - lastAttemptTime
minBackoffObserved = minOf(minBackoffObserved, backoffTime)
maxBackoffObserved = maxOf(maxBackoffObserved, backoffTime)
}
lastAttemptTime = currentTime

CompletableFuture.runAsync {
throw VALIDATION_EXCEPTION
}
}
}

val thrown = catchCoroutineException {
amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {})
}

assertThat(thrown)
.isInstanceOf(ValidationException::class.java)
.hasMessage("Resource validation failed")
assertThat(minBackoffObserved).isGreaterThanOrEqualTo(100)
assertThat(maxBackoffObserved).isLessThanOrEqualTo(10000)
}

@Test
fun `verify onError callback is called with final exception`(): Unit = runBlocking {
var errorCaught: Exception? = null

streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
CompletableFuture.runAsync {
throw VALIDATION_EXCEPTION
}
}
}

val thrown = catchCoroutineException {
amazonQStreamingClient.exportResultArchive(
"test-id",
ExportIntent.TRANSFORMATION,
null,
{ errorCaught = it },
{}
)
}

assertThat(thrown)
.isInstanceOf(ValidationException::class.java)
.hasMessage("Resource validation failed")
assertThat(errorCaught).isEqualTo(VALIDATION_EXCEPTION)
}

private suspend fun catchCoroutineException(block: suspend () -> Unit): Throwable {
try {
block()
error("Expected exception was not thrown")
} catch (e: Throwable) {
return e
}
}

companion object {
private val VALIDATION_EXCEPTION = ValidationException.builder()
.message("Resource validation failed")
.build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import software.amazon.awssdk.utils.IoUtils
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.jetbrains.core.AwsClientManager
import software.aws.toolkits.jetbrains.services.amazonq.RetryableOperation
import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.CodeWhispererCodeScanSession.Companion.APPLICATION_ZIP
import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.CodeWhispererCodeScanSession.Companion.AWS_KMS
import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.CodeWhispererCodeScanSession.Companion.CONTENT_MD5
Expand Down Expand Up @@ -208,41 +209,3 @@ fun getTelemetryErrorMessage(e: Exception, featureUseCase: CodeWhispererConstant
else -> message("testgen.message.failed")
}
}

class RetryableOperation<T> {
private var attempts = 0
private var currentDelay = INITIAL_DELAY
private var lastException: Exception? = null

fun execute(
operation: () -> T,
isRetryable: (Exception) -> Boolean,
errorHandler: (Exception, Int) -> Nothing,
): T {
while (attempts < MAX_RETRY_ATTEMPTS) {
try {
return operation()
} catch (e: Exception) {
lastException = e

attempts++
if (attempts < MAX_RETRY_ATTEMPTS && isRetryable(e)) {
Thread.sleep(currentDelay)
currentDelay = (currentDelay * 2).coerceAtMost(MAX_BACKOFF)
continue
}

errorHandler(e, attempts)
}
}

// This line should never be reached due to errorHandler throwing exception
throw RuntimeException("Unexpected state after $attempts attempts")
}

companion object {
private const val INITIAL_DELAY = 100L // milliseconds
private const val MAX_BACKOFF = 10000L // milliseconds
private const val MAX_RETRY_ATTEMPTS = 3
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package software.aws.toolkits.jetbrains.services.amazonq

import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import software.amazon.awssdk.core.exception.RetryableException
import kotlin.random.Random

class RetryableOperation<T> {
private var attempts = 0
private var currentDelay = INITIAL_DELAY

private fun getJitteredDelay(): Long {
currentDelay = (currentDelay * 2).coerceAtMost(MAX_BACKOFF)
return (currentDelay * (0.5 + Random.nextDouble(0.5))).toLong()
}

fun execute(
operation: () -> T,
isRetryable: (Exception) -> Boolean = { it is RetryableException },
errorHandler: ((Exception, Int) -> Nothing),
): T = runBlocking {
executeSuspend(operation, isRetryable, errorHandler)
}

suspend fun executeSuspend(
operation: suspend () -> T,
isRetryable: (Exception) -> Boolean = { it is RetryableException },
errorHandler: (suspend (Exception, Int) -> Nothing),
): T {
while (attempts < MAX_RETRY_ATTEMPTS) {
try {
return operation()
} catch (e: Exception) {
attempts++
if (attempts >= MAX_RETRY_ATTEMPTS || !isRetryable(e)) {
errorHandler.invoke(e, attempts)
}
delay(getJitteredDelay())
}
}

throw RuntimeException("Unexpected state after $attempts attempts")
}

companion object {
private const val INITIAL_DELAY = 100L
private const val MAX_BACKOFF = 10000L
private const val MAX_RETRY_ATTEMPTS = 3
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,21 @@ import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import kotlinx.coroutines.future.await
import software.amazon.awssdk.core.exception.SdkException
import software.amazon.awssdk.services.codewhispererstreaming.CodeWhispererStreamingAsyncClient
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportContext
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveResponseHandler
import software.amazon.awssdk.services.codewhispererstreaming.model.ThrottlingException
import software.amazon.awssdk.services.codewhispererstreaming.model.ValidationException
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.core.awsClient
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import software.aws.toolkits.jetbrains.services.amazonq.RetryableOperation
import java.time.Instant
import java.util.concurrent.TimeoutException
import java.util.concurrent.atomic.AtomicReference

@Service(Service.Level.PROJECT)
Expand Down Expand Up @@ -54,30 +59,45 @@ class AmazonQStreamingClient(private val project: Project) {
val checksum = AtomicReference("")

try {
val result = streamingBearerClient().exportResultArchive(
{
it.exportId(exportId)
it.exportIntent(exportIntent)
it.exportContext(exportContext)
RetryableOperation<Unit>().executeSuspend(
operation = {
val result = streamingBearerClient().exportResultArchive(
{
it.exportId(exportId)
it.exportIntent(exportIntent)
it.exportContext(exportContext)
},
ExportResultArchiveResponseHandler.builder().subscriber(
ExportResultArchiveResponseHandler.Visitor.builder()
.onBinaryMetadataEvent {
checksum.set(it.contentChecksum())
}.onBinaryPayloadEvent {
val payloadBytes = it.bytes().asByteArray()
byteBufferList.add(payloadBytes)
}.onDefault {
LOG.warn { "Received unknown payload stream: $it" }
}
.build()
)
.build()
)
result.await()
},
ExportResultArchiveResponseHandler.builder().subscriber(
ExportResultArchiveResponseHandler.Visitor.builder()
.onBinaryMetadataEvent {
checksum.set(it.contentChecksum())
}.onBinaryPayloadEvent {
val payloadBytes = it.bytes().asByteArray()
byteBufferList.add(payloadBytes)
}.onDefault {
LOG.warn { "Received unknown payload stream: $it" }
}
.build()
)
.build()
isRetryable = { e ->
when (e) {
is ValidationException,
is ThrottlingException,
is SdkException,
is TimeoutException,
-> true
else -> false
}
},
errorHandler = { e, attempts ->
onError(e)
throw e
}
)
result.await()
} catch (e: Exception) {
onError(e)
throw e
} finally {
onStreamingFinished(startTime)
}
Expand Down

0 comments on commit ac43387

Please sign in to comment.