Skip to content

Commit

Permalink
fix(amazonq): for /test adding backoff and retry for payload upload A…
Browse files Browse the repository at this point in the history
…PIs. (#5310)

* Adding backoff and retry for payload upload APIs.

Co-authored-by: Laxman Reddy <[email protected]>
  • Loading branch information
ashishrp-aws and laileni-aws authored Feb 3, 2025
1 parent 02ec702 commit 5470418
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import software.amazon.awssdk.core.exception.SdkServiceException
import software.amazon.awssdk.services.codewhispererruntime.model.GetTestGenerationResponse
import software.amazon.awssdk.services.codewhispererruntime.model.Range
import software.amazon.awssdk.services.codewhispererruntime.model.StartTestGenerationResponse
import software.amazon.awssdk.services.codewhispererruntime.model.TargetCode
import software.amazon.awssdk.services.codewhispererruntime.model.TestGenerationJobStatus
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportContext
import software.amazon.awssdk.services.codewhispererstreaming.model.ExportIntent
import software.aws.toolkits.core.utils.Waiters.waitUntil
import software.aws.toolkits.core.utils.debug
import software.aws.toolkits.core.utils.error
import software.aws.toolkits.core.utils.getLogger
Expand Down Expand Up @@ -58,6 +58,7 @@ import java.io.ByteArrayOutputStream
import java.io.File
import java.io.IOException
import java.nio.file.Paths
import java.time.Duration
import java.time.Instant
import java.util.concurrent.atomic.AtomicBoolean
import java.util.zip.ZipInputStream
Expand Down Expand Up @@ -109,29 +110,38 @@ class CodeWhispererUTGChatManager(val project: Project, private val cs: Coroutin

// 2nd API call: StartTestGeneration
val startTestGenerationResponse = try {
startTestGeneration(
uploadId = createUploadUrlResponse.uploadId(),
targetCode = listOf(
TargetCode.builder()
.relativeTargetPath(codeTestResponseContext.currentFileRelativePath.toString())
.targetLineRangeList(
if (selectionRange != null) {
listOf(
selectionRange
var response: StartTestGenerationResponse? = null

waitUntil(
succeedOn = { response?.sdkHttpResponse()?.statusCode() == 200 },
maxDuration = Duration.ofSeconds(1), // 1 second timeout
) {
try {
response = startTestGeneration(
uploadId = createUploadUrlResponse.uploadId(),
targetCode = listOf(
TargetCode.builder()
.relativeTargetPath(codeTestResponseContext.currentFileRelativePath.toString())
.targetLineRangeList(
if (selectionRange != null) {
listOf(selectionRange)
} else {
emptyList()
}
)
} else {
emptyList()
}
)
.build()
),
userInput = prompt
)
} catch (e: Exception) {
val statusCode = when {
e is SdkServiceException -> e.statusCode()
else -> 400
.build()
),
userInput = prompt
)
delay(200)
response?.testGenerationJob() != null
} catch (e: Exception) {
throw e
}
}

response ?: throw RuntimeException("Failed to start test generation")
} catch (e: Exception) {
LOG.error(e) { "Unexpected error while creating test generation job" }
val errorMessage = getTelemetryErrorMessage(e, CodeWhispererConstants.FeatureName.TEST_GENERATION)
throw CodeTestException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import software.amazon.awssdk.services.codewhispererruntime.model.CodeAnalysisUp
import software.amazon.awssdk.services.codewhispererruntime.model.CodeFixUploadContext
import software.amazon.awssdk.services.codewhispererruntime.model.CreateUploadUrlRequest
import software.amazon.awssdk.services.codewhispererruntime.model.CreateUploadUrlResponse
import software.amazon.awssdk.services.codewhispererruntime.model.InternalServerException
import software.amazon.awssdk.services.codewhispererruntime.model.ThrottlingException
import software.amazon.awssdk.services.codewhispererruntime.model.UploadContext
import software.amazon.awssdk.services.codewhispererruntime.model.UploadIntent
import software.amazon.awssdk.utils.IoUtils
Expand Down Expand Up @@ -82,40 +84,50 @@ class CodeWhispererZipUploadManager(private val project: Project) {
requestHeaders: Map<String, String>?,
featureUseCase: CodeWhispererConstants.FeatureName,
) {
try {
val uploadIdJson = """{"uploadId":"$uploadId"}"""
HttpRequests.put(url, "application/zip").userAgent(AwsClientManager.getUserAgent()).tuner {
if (requestHeaders.isNullOrEmpty()) {
it.setRequestProperty(CONTENT_MD5, md5)
it.setRequestProperty(CONTENT_TYPE, APPLICATION_ZIP)
it.setRequestProperty(SERVER_SIDE_ENCRYPTION, AWS_KMS)
if (kmsArn?.isNotEmpty() == true) {
it.setRequestProperty(SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID, kmsArn)
}
it.setRequestProperty(SERVER_SIDE_ENCRYPTION_CONTEXT, Base64.getEncoder().encodeToString(uploadIdJson.toByteArray()))
} else {
requestHeaders.forEach { entry ->
it.setRequestProperty(entry.key, entry.value)
RetryableOperation<Unit>().execute(
operation = {
val uploadIdJson = """{"uploadId":"$uploadId"}"""
HttpRequests.put(url, "application/zip").userAgent(AwsClientManager.getUserAgent()).tuner {
if (requestHeaders.isNullOrEmpty()) {
it.setRequestProperty(CONTENT_MD5, md5)
it.setRequestProperty(CONTENT_TYPE, APPLICATION_ZIP)
it.setRequestProperty(SERVER_SIDE_ENCRYPTION, AWS_KMS)
if (kmsArn?.isNotEmpty() == true) {
it.setRequestProperty(SERVER_SIDE_ENCRYPTION_AWS_KMS_KEY_ID, kmsArn)
}
it.setRequestProperty(SERVER_SIDE_ENCRYPTION_CONTEXT, Base64.getEncoder().encodeToString(uploadIdJson.toByteArray()))
} else {
requestHeaders.forEach { entry ->
it.setRequestProperty(entry.key, entry.value)
}
}
}.connect {
val connection = it.connection as HttpURLConnection
connection.setFixedLengthStreamingMode(fileToUpload.length())
IoUtils.copy(fileToUpload.inputStream(), connection.outputStream)
}
},
isRetryable = { e ->
when (e) {
is IOException -> true
else -> false
}
},
errorHandler = { e, attempts ->
val errorMessage = getTelemetryErrorMessage(e, featureUseCase)
when (featureUseCase) {
CodeWhispererConstants.FeatureName.CODE_REVIEW ->
codeScanServerException("CreateUploadUrlException: $errorMessage")
CodeWhispererConstants.FeatureName.TEST_GENERATION ->
throw CodeTestException(
"UploadTestArtifactToS3Error: $errorMessage",
"UploadTestArtifactToS3Error",
message("testgen.error.generic_technical_error_message")
)
else -> throw RuntimeException("$errorMessage (after $attempts attempts)")
}
}.connect {
val connection = it.connection as HttpURLConnection
connection.setFixedLengthStreamingMode(fileToUpload.length())
IoUtils.copy(fileToUpload.inputStream(), connection.outputStream)
}
} catch (e: Exception) {
LOG.debug { "$featureUseCase: Artifact failed to upload in the S3 bucket: ${e.message}" }
val errorMessage = getTelemetryErrorMessage(e, featureUseCase)
when (featureUseCase) {
CodeWhispererConstants.FeatureName.CODE_REVIEW -> codeScanServerException("CreateUploadUrlException: $errorMessage")
CodeWhispererConstants.FeatureName.TEST_GENERATION -> throw CodeTestException(
"UploadTestArtifactToS3Error: $errorMessage",
"UploadTestArtifactToS3Error",
message("testgen.error.generic_technical_error_message")
)
else -> throw RuntimeException(errorMessage) // Adding else for safety check
}
}
)
}

fun createUploadUrl(
Expand All @@ -124,35 +136,44 @@ class CodeWhispererZipUploadManager(private val project: Project) {
uploadTaskType: CodeWhispererConstants.UploadTaskType,
taskName: String,
featureUseCase: CodeWhispererConstants.FeatureName,
): CreateUploadUrlResponse = try {
CodeWhispererClientAdaptor.getInstance(project).createUploadUrl(
CreateUploadUrlRequest.builder()
.contentMd5(md5Content)
.artifactType(artifactType)
.uploadIntent(getUploadIntent(uploadTaskType))
.uploadContext(
// For UTG we don't need uploadContext but sending else case as UploadContext
if (uploadTaskType == CodeWhispererConstants.UploadTaskType.CODE_FIX) {
UploadContext.fromCodeFixUploadContext(CodeFixUploadContext.builder().codeFixName(taskName).build())
} else {
UploadContext.fromCodeAnalysisUploadContext(CodeAnalysisUploadContext.builder().codeScanName(taskName).build())
}
)
.build()
)
} catch (e: Exception) {
LOG.debug { "$featureUseCase: Create Upload URL failed: ${e.message}" }
val errorMessage = getTelemetryErrorMessage(e, featureUseCase)
when (featureUseCase) {
CodeWhispererConstants.FeatureName.CODE_REVIEW -> codeScanServerException("CreateUploadUrlException: $errorMessage")
CodeWhispererConstants.FeatureName.TEST_GENERATION -> throw CodeTestException(
"CreateUploadUrlError: $errorMessage",
"CreateUploadUrlError",
message("testgen.error.generic_technical_error_message")
): CreateUploadUrlResponse = RetryableOperation<CreateUploadUrlResponse>().execute(
operation = {
CodeWhispererClientAdaptor.getInstance(project).createUploadUrl(
CreateUploadUrlRequest.builder()
.contentMd5(md5Content)
.artifactType(artifactType)
.uploadIntent(getUploadIntent(uploadTaskType))
.uploadContext(
// For UTG we don't need uploadContext but sending else case as UploadContext
if (uploadTaskType == CodeWhispererConstants.UploadTaskType.CODE_FIX) {
UploadContext.fromCodeFixUploadContext(CodeFixUploadContext.builder().codeFixName(taskName).build())
} else {
UploadContext.fromCodeAnalysisUploadContext(CodeAnalysisUploadContext.builder().codeScanName(taskName).build())
}
)
.build()
)
else -> throw RuntimeException(errorMessage) // Adding else for safety check
},
isRetryable = { e ->
e is ThrottlingException || e is InternalServerException
},
errorHandler = { e, attempts ->
val errorMessage = getTelemetryErrorMessage(e, featureUseCase)
when (featureUseCase) {
CodeWhispererConstants.FeatureName.CODE_REVIEW ->
codeScanServerException("CreateUploadUrlException after $attempts attempts: $errorMessage")

CodeWhispererConstants.FeatureName.TEST_GENERATION ->
throw CodeTestException(
"CreateUploadUrlError after $attempts attempts: $errorMessage",
"CreateUploadUrlError",
message("testgen.error.generic_technical_error_message")
)

else -> throw RuntimeException("$errorMessage (after $attempts attempts)")
}
}
}
)

private fun getUploadIntent(uploadTaskType: CodeWhispererConstants.UploadTaskType): UploadIntent = when (uploadTaskType) {
CodeWhispererConstants.UploadTaskType.SCAN_FILE -> UploadIntent.AUTOMATIC_FILE_SECURITY_SCAN
Expand Down Expand Up @@ -187,3 +208,41 @@ fun getTelemetryErrorMessage(e: Exception, featureUseCase: CodeWhispererConstant
else -> message("testgen.message.failed")
}
}

class RetryableOperation<T> {
private var attempts = 0
private var currentDelay = INITIAL_DELAY
private var lastException: Exception? = null

fun execute(
operation: () -> T,
isRetryable: (Exception) -> Boolean,
errorHandler: (Exception, Int) -> Nothing,
): T {
while (attempts < MAX_RETRY_ATTEMPTS) {
try {
return operation()
} catch (e: Exception) {
lastException = e

attempts++
if (attempts < MAX_RETRY_ATTEMPTS && isRetryable(e)) {
Thread.sleep(currentDelay)
currentDelay = (currentDelay * 2).coerceAtMost(MAX_BACKOFF)
continue
}

errorHandler(e, attempts)
}
}

// This line should never be reached due to errorHandler throwing exception
throw RuntimeException("Unexpected state after $attempts attempts")
}

companion object {
private const val INITIAL_DELAY = 100L // milliseconds
private const val MAX_BACKOFF = 10000L // milliseconds
private const val MAX_RETRY_ATTEMPTS = 3
}
}

0 comments on commit 5470418

Please sign in to comment.