Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(amazonq): Adding backoff and retry for export result archive streaming API. #5320

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ package software.aws.toolkits.jetbrains.services.amazonq.clients

import com.intellij.testFramework.RuleChain
import com.intellij.testFramework.replaceService
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.mockito.kotlin.any
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.doAnswer
import org.mockito.kotlin.doReturn
import org.mockito.kotlin.mock
import org.mockito.kotlin.stub
Expand All @@ -20,6 +23,7 @@ import software.amazon.awssdk.services.codewhispererstreaming.CodeWhispererStrea
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveRequest
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveResponseHandler
import software.amazon.awssdk.services.codewhispererstreaming.model.ValidationException
import software.amazon.awssdk.services.ssooidc.SsoOidcClient
import software.aws.toolkits.core.TokenConnectionSettings
import software.aws.toolkits.core.utils.test.aString
Expand Down Expand Up @@ -81,4 +85,156 @@ class AmazonQStreamingClientTest : AmazonQTestBase() {
verify(streamingBearerClient).exportResultArchive(requestCaptor.capture(), handlerCaptor.capture())
}
}

@Test
fun `verify retry on ValidationException`(): Unit = runBlocking {
var attemptCount = 0
streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
attemptCount++
if (attemptCount <= 2) {
CompletableFuture.runAsync {
throw VALIDATION_EXCEPTION
}
} else {
CompletableFuture.completedFuture(mock())
}
}
}

amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {})

assertThat(attemptCount).isEqualTo(3)
}

@Test
fun `verify retry gives up after max attempts`(): Unit = runBlocking {
var attemptCount = 0
streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
attemptCount++
CompletableFuture.runAsync {
throw VALIDATION_EXCEPTION
}
}
}

val thrown = catchCoroutineException {
amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {})
}

assertThat(attemptCount).isEqualTo(3)
assertThat(thrown)
.isInstanceOf(ValidationException::class.java)
.hasMessage("Resource validation failed")
}

@Test
fun `verify no retry on non-retryable exception`(): Unit = runBlocking {
var attemptCount = 0

streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
attemptCount++
CompletableFuture.runAsync {
throw IllegalArgumentException("Non-retryable error")
}
}
}

val thrown = catchCoroutineException {
amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {})
}

assertThat(attemptCount).isEqualTo(1)
assertThat(thrown)
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Non-retryable error")
}

@Test
fun `verify backoff timing between retries`(): Unit = runBlocking {
var lastAttemptTime = 0L
var minBackoffObserved = Long.MAX_VALUE
var maxBackoffObserved = 0L

streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
val currentTime = System.currentTimeMillis()
if (lastAttemptTime > 0) {
val backoffTime = currentTime - lastAttemptTime
minBackoffObserved = minOf(minBackoffObserved, backoffTime)
maxBackoffObserved = maxOf(maxBackoffObserved, backoffTime)
}
lastAttemptTime = currentTime

CompletableFuture.runAsync {
throw VALIDATION_EXCEPTION
}
}
}

val thrown = catchCoroutineException {
amazonQStreamingClient.exportResultArchive("test-id", ExportIntent.TRANSFORMATION, null, {}, {})
}

assertThat(thrown)
.isInstanceOf(ValidationException::class.java)
.hasMessage("Resource validation failed")
assertThat(minBackoffObserved).isGreaterThanOrEqualTo(100)
assertThat(maxBackoffObserved).isLessThanOrEqualTo(10000)
}

@Test
fun `verify onError callback is called with final exception`(): Unit = runBlocking {
var errorCaught: Exception? = null

streamingBearerClient = mockClientManagerRule.create<CodeWhispererStreamingAsyncClient>().stub {
on {
exportResultArchive(any<ExportResultArchiveRequest>(), any<ExportResultArchiveResponseHandler>())
} doAnswer {
CompletableFuture.runAsync {
throw VALIDATION_EXCEPTION
}
}
}

val thrown = catchCoroutineException {
amazonQStreamingClient.exportResultArchive(
"test-id",
ExportIntent.TRANSFORMATION,
null,
{ errorCaught = it },
{}
)
}

assertThat(thrown)
.isInstanceOf(ValidationException::class.java)
.hasMessage("Resource validation failed")
assertThat(errorCaught).isEqualTo(VALIDATION_EXCEPTION)
}

private suspend fun catchCoroutineException(block: suspend () -> Unit): Throwable {
try {
block()
error("Expected exception was not thrown")
} catch (e: Throwable) {
return e
}
}

companion object {
private val VALIDATION_EXCEPTION = ValidationException.builder()
.message("Resource validation failed")
.build()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import software.amazon.awssdk.utils.IoUtils
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.jetbrains.core.AwsClientManager
import software.aws.toolkits.jetbrains.services.amazonq.clients.RetryableOperation
import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.CodeWhispererCodeScanSession.Companion.APPLICATION_ZIP
import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.CodeWhispererCodeScanSession.Companion.AWS_KMS
import software.aws.toolkits.jetbrains.services.codewhisperer.codescan.CodeWhispererCodeScanSession.Companion.CONTENT_MD5
Expand Down Expand Up @@ -208,41 +209,3 @@ fun getTelemetryErrorMessage(e: Exception, featureUseCase: CodeWhispererConstant
else -> message("testgen.message.failed")
}
}

class RetryableOperation<T> {
private var attempts = 0
private var currentDelay = INITIAL_DELAY
private var lastException: Exception? = null

fun execute(
operation: () -> T,
isRetryable: (Exception) -> Boolean,
errorHandler: (Exception, Int) -> Nothing,
): T {
while (attempts < MAX_RETRY_ATTEMPTS) {
try {
return operation()
} catch (e: Exception) {
lastException = e

attempts++
if (attempts < MAX_RETRY_ATTEMPTS && isRetryable(e)) {
Thread.sleep(currentDelay)
currentDelay = (currentDelay * 2).coerceAtMost(MAX_BACKOFF)
continue
}

errorHandler(e, attempts)
}
}

// This line should never be reached due to errorHandler throwing exception
throw RuntimeException("Unexpected state after $attempts attempts")
}

companion object {
private const val INITIAL_DELAY = 100L // milliseconds
private const val MAX_BACKOFF = 10000L // milliseconds
private const val MAX_RETRY_ATTEMPTS = 3
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,26 @@ package software.aws.toolkits.jetbrains.services.amazonq.clients
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import kotlinx.coroutines.delay
import kotlinx.coroutines.future.await
import software.amazon.awssdk.core.exception.RetryableException
import software.amazon.awssdk.core.exception.SdkException
import software.amazon.awssdk.services.codewhispererstreaming.CodeWhispererStreamingAsyncClient
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportContext
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportResultArchiveResponseHandler
import software.amazon.awssdk.services.codewhispererstreaming.model.ThrottlingException
import software.amazon.awssdk.services.codewhispererstreaming.model.ValidationException
import software.aws.toolkits.core.utils.getLogger
import software.aws.toolkits.core.utils.warn
import software.aws.toolkits.jetbrains.core.awsClient
import software.aws.toolkits.jetbrains.core.credentials.ToolkitConnectionManager
import software.aws.toolkits.jetbrains.core.credentials.pinning.QConnection
import java.time.Instant
import java.util.concurrent.TimeoutException
import java.util.concurrent.atomic.AtomicReference
import javax.naming.ServiceUnavailableException
import kotlin.random.Random

@Service(Service.Level.PROJECT)
class AmazonQStreamingClient(private val project: Project) {
Expand Down Expand Up @@ -54,30 +62,46 @@ class AmazonQStreamingClient(private val project: Project) {
val checksum = AtomicReference("")

try {
val result = streamingBearerClient().exportResultArchive(
{
it.exportId(exportId)
it.exportIntent(exportIntent)
it.exportContext(exportContext)
RetryableOperation<Unit>().executeSuspend(
operation = {
val result = streamingBearerClient().exportResultArchive(
{
it.exportId(exportId)
it.exportIntent(exportIntent)
it.exportContext(exportContext)
},
ExportResultArchiveResponseHandler.builder().subscriber(
ExportResultArchiveResponseHandler.Visitor.builder()
.onBinaryMetadataEvent {
checksum.set(it.contentChecksum())
}.onBinaryPayloadEvent {
val payloadBytes = it.bytes().asByteArray()
byteBufferList.add(payloadBytes)
}.onDefault {
LOG.warn { "Received unknown payload stream: $it" }
}
.build()
)
.build()
)
result.await()
},
ExportResultArchiveResponseHandler.builder().subscriber(
ExportResultArchiveResponseHandler.Visitor.builder()
.onBinaryMetadataEvent {
checksum.set(it.contentChecksum())
}.onBinaryPayloadEvent {
val payloadBytes = it.bytes().asByteArray()
byteBufferList.add(payloadBytes)
}.onDefault {
LOG.warn { "Received unknown payload stream: $it" }
}
.build()
)
.build()
isRetryable = { e ->
when (e) {
is ValidationException,
is ThrottlingException,
is ServiceUnavailableException,
is SdkException,
is TimeoutException,
-> true
else -> false
}
},
errorHandler = { e, attempts ->
onError(e)
throw e
}
)
result.await()
} catch (e: Exception) {
onError(e)
throw e
} finally {
onStreamingFinished(startTime)
}
Expand All @@ -91,3 +115,60 @@ class AmazonQStreamingClient(private val project: Project) {
fun getInstance(project: Project) = project.service<AmazonQStreamingClient>()
}
}

class RetryableOperation<T> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to its own file

private var attempts = 0
private var currentDelay = INITIAL_DELAY

private fun getJitteredDelay(): Long {
currentDelay = (currentDelay * 2).coerceAtMost(MAX_BACKOFF)
return (currentDelay * (0.5 + Random.nextDouble(0.5))).toLong()
}

fun execute(
operation: () -> T,
isRetryable: (Exception) -> Boolean = { it is RetryableException },
errorHandler: ((Exception, Int) -> Nothing)? = null,
): T {
while (attempts < MAX_RETRY_ATTEMPTS) {
try {
return operation()
} catch (e: Exception) {
attempts++
if (attempts >= MAX_RETRY_ATTEMPTS || !isRetryable(e)) {
errorHandler?.invoke(e, attempts) ?: throw e
}

Thread.sleep(getJitteredDelay())
}
}

throw RuntimeException("Unexpected state after $attempts attempts")
}
Comment on lines +128 to +147
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fun execute(
operation: () -> T,
isRetryable: (Exception) -> Boolean = { it is RetryableException },
errorHandler: ((Exception, Int) -> Nothing)? = null,
): T {
while (attempts < MAX_RETRY_ATTEMPTS) {
try {
return operation()
} catch (e: Exception) {
attempts++
if (attempts >= MAX_RETRY_ATTEMPTS || !isRetryable(e)) {
errorHandler?.invoke(e, attempts) ?: throw e
}
Thread.sleep(getJitteredDelay())
}
}
throw RuntimeException("Unexpected state after $attempts attempts")
}
fun execute(
operation: () -> T,
isRetryable: (Exception) -> Boolean = { it is RetryableException },
errorHandler: ((Exception, Int) -> Nothing)? = null,
): T = runBlocking {
executeSuspend(operation, isRetryable, errorHandler)
}

?


suspend fun executeSuspend(
operation: suspend () -> T,
isRetryable: (Exception) -> Boolean = { it is RetryableException },
errorHandler: (suspend (Exception, Int) -> Nothing)? = null,
): T {
while (attempts < MAX_RETRY_ATTEMPTS) {
try {
return operation()
} catch (e: Exception) {
attempts++
if (attempts >= MAX_RETRY_ATTEMPTS || !isRetryable(e)) {
errorHandler?.invoke(e, attempts) ?: throw e
}
delay(getJitteredDelay())
}
}

throw RuntimeException("Unexpected state after $attempts attempts")
}

companion object {
private const val INITIAL_DELAY = 100L
private const val MAX_BACKOFF = 10000L
private const val MAX_RETRY_ATTEMPTS = 3
}
}
Loading