diff --git a/src/main/java/com/aws/greengrass/deployment/DeploymentDocumentDownloader.java b/src/main/java/com/aws/greengrass/deployment/DeploymentDocumentDownloader.java index 74838f693b..f216bbba66 100644 --- a/src/main/java/com/aws/greengrass/deployment/DeploymentDocumentDownloader.java +++ b/src/main/java/com/aws/greengrass/deployment/DeploymentDocumentDownloader.java @@ -12,6 +12,7 @@ import com.aws.greengrass.deployment.exceptions.DeploymentTaskFailureException; import com.aws.greengrass.deployment.exceptions.DeviceConfigurationException; import com.aws.greengrass.deployment.exceptions.InvalidRequestException; +import com.aws.greengrass.deployment.exceptions.RetryableClientErrorException; import com.aws.greengrass.deployment.exceptions.RetryableDeploymentDocumentDownloadException; import com.aws.greengrass.deployment.exceptions.RetryableServerErrorException; import com.aws.greengrass.deployment.model.DeploymentDocument; @@ -47,6 +48,7 @@ import java.security.NoSuchAlgorithmException; import java.time.Duration; import java.util.Arrays; +import java.util.Collections; import java.util.Optional; import javax.inject.Inject; @@ -56,16 +58,13 @@ public class DeploymentDocumentDownloader { private static final Logger logger = LogManager.getLogger(DeploymentDocumentDownloader.class); private static final long MAX_DEPLOYMENT_DOCUMENT_SIZE_BYTES = 10 * ONE_MB; + private static final int MAX_CLIENT_ERROR_RETRY_COUNT = 10; private final GreengrassServiceClientFactory greengrassServiceClientFactory; private final HttpClientProvider httpClientProvider; private final DeviceConfiguration deviceConfiguration; @Setter(AccessLevel.PACKAGE) @Getter(AccessLevel.PACKAGE) - private RetryUtils.RetryConfig clientExceptionRetryConfig = - RetryUtils.RetryConfig.builder().initialRetryInterval(Duration.ofMinutes(1)) - .maxRetryInterval(Duration.ofMinutes(1)).maxAttempt(Integer.MAX_VALUE) - .retryableExceptions(Arrays.asList(RetryableDeploymentDocumentDownloadException.class, - DeviceConfigurationException.class, RetryableServerErrorException.class)).build(); + private RetryUtils.DifferentiatedRetryConfig clientExceptionRetryConfig; /** * Constructor. @@ -81,6 +80,22 @@ public DeploymentDocumentDownloader(GreengrassServiceClientFactory greengrassSer this.greengrassServiceClientFactory = greengrassServiceClientFactory; this.deviceConfiguration = deviceConfiguration; this.httpClientProvider = httpClientProvider; + + RetryUtils.RetryConfig infiniteRetryConfig = + RetryUtils.RetryConfig.builder().initialRetryInterval(Duration.ofMinutes(1)) + .maxRetryInterval(Duration.ofMinutes(1)).maxAttempt(Integer.MAX_VALUE) + .retryableExceptions(Arrays.asList(RetryableDeploymentDocumentDownloadException.class, + DeviceConfigurationException.class, RetryableServerErrorException.class)).build(); + + RetryUtils.RetryConfig finiteRetryConfig = + RetryUtils.RetryConfig.builder().initialRetryInterval(Duration.ofMinutes(1)) + .maxRetryInterval(Duration.ofMinutes(1)).maxAttempt(MAX_CLIENT_ERROR_RETRY_COUNT) + .retryableExceptions(Collections.singletonList(RetryableClientErrorException.class)).build(); + + + this.clientExceptionRetryConfig = RetryUtils.DifferentiatedRetryConfig.builder() + .retryConfigList(Arrays.asList(infiniteRetryConfig, finiteRetryConfig)) + .build(); } /** @@ -116,7 +131,8 @@ public DeploymentDocument download(String deploymentId) protected String downloadDeploymentDocument(String deploymentId) throws DeploymentTaskFailureException, RetryableDeploymentDocumentDownloadException, - DeviceConfigurationException, HashingAlgorithmUnavailableException, RetryableServerErrorException { + DeviceConfigurationException, HashingAlgorithmUnavailableException, RetryableServerErrorException, + RetryableClientErrorException { // 1. Get url, digest, and algorithm by calling gg data plane GetDeploymentConfigurationResponse response = getDeploymentConfiguration(deploymentId); @@ -168,7 +184,7 @@ private String downloadFromUrl(String deploymentId, String preSignedUrl) private GetDeploymentConfigurationResponse getDeploymentConfiguration(String deploymentId) throws RetryableDeploymentDocumentDownloadException, DeviceConfigurationException, - DeploymentTaskFailureException, RetryableServerErrorException { + DeploymentTaskFailureException, RetryableServerErrorException, RetryableClientErrorException { String thingName = Coerce.toString(deviceConfiguration.getThingName()); GetDeploymentConfigurationRequest getDeploymentConfigurationRequest = GetDeploymentConfigurationRequest.builder().deploymentId(deploymentId).coreDeviceThingName(thingName) @@ -186,7 +202,12 @@ private GetDeploymentConfigurationResponse getDeploymentConfiguration(String dep } catch (GreengrassV2DataException e) { if (RetryUtils.retryErrorCodes(e.statusCode())) { throw new RetryableServerErrorException("Failed with retryable error: " + e.statusCode() - + "while calling getDeploymetnConfiguration", e); + + " while calling getDeploymentConfiguration", e); + } + // also retry on 404s because sometimes querying DDB may fail initially due to its eventual consistency + if (e.statusCode() == HttpStatusCode.NOT_FOUND) { + throw new RetryableClientErrorException("Failed with retryable error: " + e.statusCode() + + " while calling getDeploymentConfiguration", e); } if (e.statusCode() == HttpStatusCode.FORBIDDEN) { throw new DeploymentTaskFailureException( diff --git a/src/main/java/com/aws/greengrass/deployment/exceptions/RetryableClientErrorException.java b/src/main/java/com/aws/greengrass/deployment/exceptions/RetryableClientErrorException.java new file mode 100644 index 0000000000..d426bbb1e8 --- /dev/null +++ b/src/main/java/com/aws/greengrass/deployment/exceptions/RetryableClientErrorException.java @@ -0,0 +1,21 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.aws.greengrass.deployment.exceptions; + +/** + * Exception for handling 4xx deployment failures. + */ +public class RetryableClientErrorException extends DeploymentException { + static final long serialVersionUID = -3387516993124229948L; + + public RetryableClientErrorException(String message) { + super(message); + } + + public RetryableClientErrorException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/com/aws/greengrass/util/RetryUtils.java b/src/main/java/com/aws/greengrass/util/RetryUtils.java index 1363150853..63ae1fe30d 100644 --- a/src/main/java/com/aws/greengrass/util/RetryUtils.java +++ b/src/main/java/com/aws/greengrass/util/RetryUtils.java @@ -11,7 +11,10 @@ import lombok.Getter; import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Random; public class RetryUtils { @@ -39,49 +42,89 @@ private RetryUtils() { "PMD.AvoidInstanceofChecksInCatchClause"}) public static T runWithRetry(RetryConfig retryConfig, CrashableSupplier task, String taskDescription, Logger logger) throws Exception { - long retryInterval = retryConfig.getInitialRetryInterval().toMillis(); - int attempt = 1; - // if it's not the final attempt, execute and backoff on retryable exceptions - while (attempt < retryConfig.maxAttempt) { + return runWithRetry(DifferentiatedRetryConfig.fromRetryConfig(retryConfig), task, taskDescription, logger); + } + + /** + * Run a task with differentiated retry behaviors. Different maximum retry attempts based on different exception + * types. Stop the retry when interrupted. + * @param differentiatedRetryConfig differentiated retry config + * @param task task to run + * @param taskDescription task description + * @param logger logger + * @param return type + * @return return value + * @throws Exception Exception + */ + @SuppressWarnings({"PMD.SignatureDeclareThrowsException", "PMD.AvoidCatchingGenericException", + "PMD.AvoidInstanceofChecksInCatchClause"}) + public static T runWithRetry(DifferentiatedRetryConfig differentiatedRetryConfig, + CrashableSupplier task, String taskDescription, Logger logger) + throws Exception { + long retryInterval = 0; + long totalAttempts = 0; + long totalMaxAttempts = calculateTotalMaxAttempts(differentiatedRetryConfig); + Map attemptMap = new HashMap<>(); + differentiatedRetryConfig.getRetryConfigList() + .forEach(retryConfig -> attemptMap.put(retryConfig, 1)); + + while (totalAttempts < totalMaxAttempts) { if (Thread.currentThread().isInterrupted()) { throw new InterruptedException(taskDescription + " task is interrupted"); } + try { return task.apply(); } catch (Exception e) { if (e instanceof InterruptedException) { throw e; } - if (retryConfig.retryableExceptions.stream().anyMatch(c -> c.isInstance(e))) { - LogEventBuilder logBuild = logger.atDebug(taskDescription); - // Log first and every LOG_ON_FAILURE_COUNT failed attempt at info so as not to spam logs - // After the initial ramp up period , the task would be retried every 1 min and hence - // the failure will be logged once every 20 minutes. - if (attempt == 1 || attempt % LOG_ON_FAILURE_COUNT == 0) { - logBuild = logger.atInfo(taskDescription); - } - logBuild.kv("task-attempt", attempt).setCause(e) - .log("task failed and will be retried"); - // Backoff with jitter strategy from EqualJitterBackoffStrategy in AWS SDK - Thread.sleep(retryInterval / 2 + RANDOM.nextInt((int) (retryInterval / 2 + 1))); - if (retryInterval < retryConfig.getMaxRetryInterval().toMillis()) { - retryInterval = retryInterval * 2; - } else { - retryInterval = retryConfig.getMaxRetryInterval().toMillis(); + + boolean foundExceptionInMap = false; + for (RetryConfig retryConfig : differentiatedRetryConfig.getRetryConfigList()) { + if (retryConfig.getRetryableExceptions().stream().anyMatch(c -> c.isInstance(e))) { + foundExceptionInMap = true; + // increment attempt count + int attempt = attemptMap.get(retryConfig); + if (attempt >= retryConfig.getMaxAttempt()) { + throw e; + } + attemptMap.put(retryConfig, attempt + 1); + + // log the message + LogEventBuilder logBuilder = attempt == 1 || attempt % LOG_ON_FAILURE_COUNT == 0 + ? logger.atInfo(taskDescription) + : logger.atDebug(taskDescription); + logBuilder.kv("task-attempt", attempt).setCause(e).log("task failed and will be retried"); + + // sleep with back-off + if (retryInterval == 0) { + retryInterval = retryConfig.getInitialRetryInterval().toMillis(); + } + Thread.sleep(retryInterval / 2 + RANDOM.nextInt((int) (retryInterval / 2 + 1))); + retryInterval = Math.min(retryInterval * 2, retryConfig.getMaxRetryInterval().toMillis()); + + // break since exception is found + break; } - attempt++; - } else { + } + if (!foundExceptionInMap) { throw e; } + totalAttempts++; } } - // if it's the final attempt, return directly - if (Thread.currentThread().isInterrupted()) { - throw new InterruptedException(taskDescription + " task is interrupted"); - } return task.apply(); } + // Use long to avoid integer overflow + private static long calculateTotalMaxAttempts(DifferentiatedRetryConfig config) { + return config.getRetryConfigList().stream() + .mapToLong(RetryConfig::getMaxAttempt) + .sum(); + } + + @Builder(toBuilder = true) @Getter public static class RetryConfig { @@ -91,9 +134,35 @@ public static class RetryConfig { Duration maxRetryInterval = Duration.ofMinutes(1L); @Builder.Default int maxAttempt = 10; + // keep compatibility with older versions List retryableExceptions; } + @Builder(toBuilder = true) + @Getter + public static class DifferentiatedRetryConfig { + // map between set of exception classes to retry on and max retry attempt for each set + List retryConfigList; + + /** + * Create a DifferentiatedRetryConfig from RetryConfig. + * @param retryConfig retryConfig + * @return differentiatedRetryConfig + */ + public static DifferentiatedRetryConfig fromRetryConfig(RetryConfig retryConfig) { + return DifferentiatedRetryConfig.builder() + .retryConfigList(Collections.singletonList(retryConfig)) + .build(); + } + + // Used for unit test + public void setInitialRetryIntervalForAll(Duration duration) { + retryConfigList.forEach(retryConfig -> retryConfig.initialRetryInterval = duration); + } + } + + + /** * Check if given error code qualifies for triggering retry mechanism. * diff --git a/src/test/java/com/aws/greengrass/deployment/DeploymentDocumentDownloaderTest.java b/src/test/java/com/aws/greengrass/deployment/DeploymentDocumentDownloaderTest.java index 49648ffc84..e11e8eeeac 100644 --- a/src/test/java/com/aws/greengrass/deployment/DeploymentDocumentDownloaderTest.java +++ b/src/test/java/com/aws/greengrass/deployment/DeploymentDocumentDownloaderTest.java @@ -11,6 +11,7 @@ import com.aws.greengrass.deployment.converter.DeploymentDocumentConverter; import com.aws.greengrass.deployment.exceptions.DeploymentTaskFailureException; import com.aws.greengrass.deployment.exceptions.DeviceConfigurationException; +import com.aws.greengrass.deployment.exceptions.RetryableClientErrorException; import com.aws.greengrass.deployment.exceptions.RetryableDeploymentDocumentDownloadException; import com.aws.greengrass.deployment.exceptions.RetryableServerErrorException; import com.aws.greengrass.deployment.model.DeploymentDocument; @@ -60,7 +61,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; @ExtendWith({GGExtension.class, MockitoExtension.class}) class DeploymentDocumentDownloaderTest { @@ -387,9 +393,10 @@ void GIVEN_download_content_with_invalid_format_WHEN_download_THEN_throws_with_p } @Test - void GIVEN_gg_client_response_500_error_code_WHEN_download_THEN_retry(ExtensionContext context) + void GIVEN_gg_client_response_500_error_and_404_error_WHEN_download_THEN_retry(ExtensionContext context) throws Exception { ignoreExceptionOfType(context, RetryableServerErrorException.class); + ignoreExceptionOfType(context, RetryableClientErrorException.class); when(httpClientProvider.getSdkHttpClient()).thenReturn(httpClient); Path testFcsDeploymentJsonPath = @@ -400,10 +407,13 @@ void GIVEN_gg_client_response_500_error_code_WHEN_download_THEN_retry(ExtensionC String url = "https://www.presigned.com/a.json"; - Exception e = GreengrassV2DataException.builder().statusCode(500).build(); + Exception e1 = GreengrassV2DataException.builder().statusCode(500).build(); + Exception e2 = GreengrassV2DataException.builder().statusCode(404).build(); + // mock gg client when(greengrassV2DataClient.getDeploymentConfiguration(Mockito.any(GetDeploymentConfigurationRequest.class))) - .thenThrow(e) + .thenThrow(e1) + .thenThrow(e2) .thenReturn(GetDeploymentConfigurationResponse.builder().preSignedUrl(url) .integrityCheck(IntegrityCheck.builder().algorithm("SHA-256").digest(expectedDigest).build()) .build()); @@ -416,13 +426,13 @@ void GIVEN_gg_client_response_500_error_code_WHEN_download_THEN_retry(ExtensionC .responseBody(AbortableInputStream.create(Files.newInputStream(testFcsDeploymentJsonPath))) .build()); - downloader.setClientExceptionRetryConfig( - downloader.getClientExceptionRetryConfig().toBuilder().initialRetryInterval(Duration.ZERO).build()); + downloader.getClientExceptionRetryConfig().setInitialRetryIntervalForAll(Duration.ZERO); downloader.download(DEPLOYMENT_ID); // verify - verify(greengrassV2DataClient, times(2)).getDeploymentConfiguration(GetDeploymentConfigurationRequest.builder().deploymentId(DEPLOYMENT_ID).coreDeviceThingName(THING_NAME) - .build()); + verify(greengrassV2DataClient, times(3)).getDeploymentConfiguration( + GetDeploymentConfigurationRequest.builder().deploymentId(DEPLOYMENT_ID).coreDeviceThingName(THING_NAME) + .build()); } @Test void GIVEN_regional_s3_endpoint_in_device_config_WHEN_download_THEN_request_uses_regional_endpoint() diff --git a/src/test/java/com/aws/greengrass/util/RetryUtilsTest.java b/src/test/java/com/aws/greengrass/util/RetryUtilsTest.java index 4c43246799..eb9ca893f6 100644 --- a/src/test/java/com/aws/greengrass/util/RetryUtilsTest.java +++ b/src/test/java/com/aws/greengrass/util/RetryUtilsTest.java @@ -11,7 +11,9 @@ import java.io.IOException; import java.time.Duration; -import java.util.Arrays; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -22,7 +24,7 @@ class RetryUtilsTest { Logger logger = LogManager.getLogger(this.getClass()).createChild(); RetryUtils.RetryConfig config = RetryUtils.RetryConfig.builder().initialRetryInterval(Duration.ofSeconds(1)) .maxRetryInterval(Duration.ofSeconds(1)).maxAttempt(Integer.MAX_VALUE).retryableExceptions( - Arrays.asList(IOException.class)).build(); + Collections.singletonList(IOException.class)).build(); @Test void GIVEN_retryableException_WHEN_runWithRetry_THEN_retry() throws Exception { @@ -47,4 +49,34 @@ void GIVEN_nonRetryableException_WHEN_runWithRetry_THEN_throwException() { }, "", logger)); assertEquals(1, invoked.get()); } + + @Test + void GIVEN_differentiatedRetryConfig_WHEN_runWithRetry_THEN_retryDifferently() { + AtomicInteger invoked = new AtomicInteger(0); + List configList = new ArrayList<>(); + + configList.add(RetryUtils.RetryConfig.builder().initialRetryInterval(Duration.ofSeconds(1)) + .maxRetryInterval(Duration.ofSeconds(1)).maxAttempt(3).retryableExceptions( + Collections.singletonList(IOException.class)).build()); + + configList.add(RetryUtils.RetryConfig.builder().initialRetryInterval(Duration.ofSeconds(1)) + .maxRetryInterval(Duration.ofSeconds(1)).maxAttempt(2).retryableExceptions( + Collections.singletonList(RuntimeException.class)).build()); + + RetryUtils.DifferentiatedRetryConfig config = RetryUtils.DifferentiatedRetryConfig.builder() + .retryConfigList(configList) + .build(); + + assertThrows(RuntimeException.class, () -> RetryUtils.runWithRetry(config, () -> { + // throw IO exception on even number attempts -> 2 times + // throw runtime exception on odd number attempts -> 1 times + // at last it will throw runtime exception out because we only allow 2 max retries + if (invoked.getAndIncrement() % 2 == 0) { + throw new IOException(); + } else { + throw new RuntimeException(); + } + }, "", logger)); + assertEquals(4, invoked.get()); + } }