Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 3279 - Upgrade ECS SDK in clients module #3438

Merged
merged 7 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions java/clients/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,6 @@
<artifactId>ecr</artifactId>
<version>${aws-java-sdk-v2.version}</version>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-ecs</artifactId>
<version>${aws-java-sdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>cloudwatchevents</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@

package sleeper.clients.deploy;

import com.amazonaws.services.ecs.AmazonECS;
import com.amazonaws.services.ecs.model.ListTasksRequest;
import com.amazonaws.services.ecs.model.StopTaskRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.ecs.EcsClient;
import software.amazon.awssdk.services.lambda.LambdaClient;

import sleeper.core.properties.instance.InstanceProperties;
Expand All @@ -34,7 +32,7 @@
public class RestartTasks {

private static final Logger LOGGER = LoggerFactory.getLogger(RestartTasks.class);
private final AmazonECS ecs;
private final EcsClient ecs;
private final LambdaClient lambda;
private final InstanceProperties properties;
private final boolean skip;
Expand Down Expand Up @@ -70,20 +68,20 @@ private void restartTasks(InstanceProperty clusterProperty, InstanceProperty lam
}

private void stopTasksInCluster(String cluster) {
ecs.listTasks(new ListTasksRequest().withCluster(cluster)).getTaskArns()
.forEach(task -> ecs.stopTask(new StopTaskRequest().withTask(task).withCluster(cluster)));
ecs.listTasks(builder -> builder.cluster(cluster)).taskArns()
.forEach(task -> ecs.stopTask(builder -> builder.cluster(cluster).task(task)));
}

public static final class Builder {
private AmazonECS ecs;
private EcsClient ecs;
private LambdaClient lambda;
private InstanceProperties properties;
private boolean skip;

private Builder() {
}

public Builder ecs(AmazonECS ecs) {
public Builder ecs(EcsClient ecs) {
this.ecs = ecs;
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
*/
package sleeper.clients.teardown;

import com.amazonaws.services.ecs.AmazonECS;
import com.amazonaws.services.ecs.model.ListTasksRequest;
import com.amazonaws.services.ecs.model.ListTasksResult;
import com.amazonaws.services.ecs.model.StopTaskRequest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.cloudwatchevents.CloudWatchEventsClient;
import software.amazon.awssdk.services.ecs.EcsClient;
import software.amazon.awssdk.services.emr.EmrClient;
import software.amazon.awssdk.services.emr.model.ListClustersResponse;
import software.amazon.awssdk.services.emrserverless.EmrServerlessClient;
Expand All @@ -32,6 +29,7 @@
import sleeper.core.properties.SleeperProperty;
import sleeper.core.properties.instance.InstanceProperties;
import sleeper.core.util.StaticRateLimit;
import sleeper.core.util.ThreadSleep;

import java.util.List;
import java.util.function.Consumer;
Expand All @@ -46,24 +44,27 @@ public class ShutdownSystemProcesses {
private static final Logger LOGGER = LoggerFactory.getLogger(ShutdownSystemProcesses.class);

private final CloudWatchEventsClient cloudWatch;
private final AmazonECS ecs;
private final EcsClient ecs;
private final EmrClient emrClient;
private final EmrServerlessClient emrServerlessClient;
private final StaticRateLimit<ListClustersResponse> listActiveClustersLimit;
private final ThreadSleep threadSleep;

public ShutdownSystemProcesses(TearDownClients clients) {
this(clients.getCloudWatch(), clients.getEcs(), clients.getEmr(), clients.getEmrServerless(), EmrUtils.LIST_ACTIVE_CLUSTERS_LIMIT);
this(clients.getCloudWatch(), clients.getEcs(), clients.getEmr(), clients.getEmrServerless(), EmrUtils.LIST_ACTIVE_CLUSTERS_LIMIT, Thread::sleep);
}

public ShutdownSystemProcesses(
CloudWatchEventsClient cloudWatch, AmazonECS ecs,
CloudWatchEventsClient cloudWatch, EcsClient ecs,
EmrClient emrClient, EmrServerlessClient emrServerlessClient,
StaticRateLimit<ListClustersResponse> listActiveClustersLimit) {
StaticRateLimit<ListClustersResponse> listActiveClustersLimit,
ThreadSleep threadSleep) {
this.cloudWatch = cloudWatch;
this.ecs = ecs;
this.emrClient = emrClient;
this.emrServerlessClient = emrServerlessClient;
this.listActiveClustersLimit = listActiveClustersLimit;
this.threadSleep = threadSleep;
}

public void shutdown(InstanceProperties instanceProperties, List<String> extraECSClusters) throws InterruptedException {
Expand All @@ -82,40 +83,36 @@ private void stopECSTasks(InstanceProperties instanceProperties, List<String> ex
}

private void stopEMRClusters(InstanceProperties properties) throws InterruptedException {
new TerminateEMRClusters(emrClient, properties.get(ID), listActiveClustersLimit).run();
new TerminateEMRClusters(emrClient, properties.get(ID), listActiveClustersLimit, threadSleep).run();
}

private void stopEMRServerlessApplication(InstanceProperties properties) throws InterruptedException {
new TerminateEMRServerlessApplications(emrServerlessClient, properties).run();
}

public static <T extends SleeperProperty> void stopTasks(AmazonECS ecs, SleeperProperties<T> properties, T property) {
public static <T extends SleeperProperty> void stopTasks(EcsClient ecs, SleeperProperties<T> properties, T property) {
if (!properties.isSet(property)) {
return;
}
stopTasks(ecs, properties.get(property));
}

private static void stopTasks(AmazonECS ecs, String clusterName) {
private static void stopTasks(EcsClient ecs, String clusterName) {
LOGGER.info("Stopping tasks for ECS cluster {}", clusterName);
forEachTaskArn(ecs, clusterName, taskArn -> {
// Rate limit for ECS StopTask is 100 burst, 40 sustained:
// https://docs.aws.amazon.com/AmazonECS/latest/APIReference/request-throttling.html
sleepForSustainedRatePerSecond(30);
ecs.stopTask(new StopTaskRequest().withCluster(clusterName).withTask(taskArn)
.withReason("Cleaning up before cdk destroy"));
ecs.stopTask(builder -> builder.cluster(clusterName).task(taskArn)
.reason("Cleaning up before cdk destroy"));
});
}

private static void forEachTaskArn(AmazonECS ecs, String clusterName, Consumer<String> consumer) {
String nextToken = null;
do {
ListTasksResult result = ecs.listTasks(
new ListTasksRequest().withCluster(clusterName).withNextToken(nextToken));

LOGGER.info("Found {} tasks", result.getTaskArns().size());
result.getTaskArns().forEach(consumer);
nextToken = result.getNextToken();
} while (nextToken != null);
private static void forEachTaskArn(EcsClient ecs, String clusterName, Consumer<String> consumer) {
ecs.listTasksPaginator(builder -> builder.cluster(clusterName))
.stream()
.peek(response -> LOGGER.info("Found {} tasks", response.taskArns().size()))
.flatMap(response -> response.taskArns().stream())
.forEach(consumer);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@

package sleeper.clients.teardown;

import com.amazonaws.services.ecs.AmazonECS;
import com.amazonaws.services.ecs.AmazonECSClientBuilder;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import software.amazon.awssdk.services.cloudformation.CloudFormationClient;
import software.amazon.awssdk.services.cloudwatchevents.CloudWatchEventsClient;
import software.amazon.awssdk.services.ecr.EcrClient;
import software.amazon.awssdk.services.ecs.EcsClient;
import software.amazon.awssdk.services.emr.EmrClient;
import software.amazon.awssdk.services.emrserverless.EmrServerlessClient;
import software.amazon.awssdk.services.s3.S3Client;
Expand All @@ -35,7 +34,7 @@ public class TearDownClients {
private final AmazonS3 s3;
private final S3Client s3v2;
private final CloudWatchEventsClient cloudWatch;
private final AmazonECS ecs;
private final EcsClient ecs;
private final EcrClient ecr;
private final EmrClient emr;
private final EmrServerlessClient emrServerless;
Expand All @@ -54,10 +53,10 @@ private TearDownClients(Builder builder) {

public static void withDefaults(TearDownOperation operation) throws IOException, InterruptedException {
AmazonS3 s3Client = AmazonS3ClientBuilder.defaultClient();
AmazonECS ecsClient = AmazonECSClientBuilder.defaultClient();
try (S3Client s3v2Client = S3Client.create();
CloudWatchEventsClient cloudWatchClient = CloudWatchEventsClient.create();
EcrClient ecrClient = EcrClient.create();
EcsClient ecsClient = EcsClient.create();
EmrClient emrClient = EmrClient.create();
EmrServerlessClient emrServerless = EmrServerlessClient.create();
CloudFormationClient cloudFormationClient = CloudFormationClient.create()) {
Expand All @@ -74,7 +73,6 @@ public static void withDefaults(TearDownOperation operation) throws IOException,
operation.tearDown(clients);
} finally {
s3Client.shutdown();
ecsClient.shutdown();
}
}

Expand All @@ -94,7 +92,7 @@ public CloudWatchEventsClient getCloudWatch() {
return cloudWatch;
}

public AmazonECS getEcs() {
public EcsClient getEcs() {
return ecs;
}

Expand All @@ -118,7 +116,7 @@ public static final class Builder {
private AmazonS3 s3;
private S3Client s3v2;
private CloudWatchEventsClient cloudWatch;
private AmazonECS ecs;
private EcsClient ecs;
private EcrClient ecr;
private EmrClient emr;
private EmrServerlessClient emrServerless;
Expand All @@ -142,7 +140,7 @@ public Builder cloudWatch(CloudWatchEventsClient cloudWatch) {
return this;
}

public Builder ecs(AmazonECS ecs) {
public Builder ecs(EcsClient ecs) {
this.ecs = ecs;
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import sleeper.core.util.PollWithRetries;
import sleeper.core.util.StaticRateLimit;
import sleeper.core.util.ThreadSleep;

import java.time.Duration;
import java.util.List;
Expand All @@ -42,11 +43,13 @@ public class TerminateEMRClusters {
private final EmrClient emrClient;
private final String clusterPrefix;
private final StaticRateLimit<ListClustersResponse> listActiveClustersLimit;
private final ThreadSleep threadSleep;

public TerminateEMRClusters(EmrClient emrClient, String instanceId, StaticRateLimit<ListClustersResponse> listActiveClustersLimit) {
public TerminateEMRClusters(EmrClient emrClient, String instanceId, StaticRateLimit<ListClustersResponse> listActiveClustersLimit, ThreadSleep threadSleep) {
this.emrClient = emrClient;
this.clusterPrefix = "sleeper-" + instanceId + "-";
this.listActiveClustersLimit = listActiveClustersLimit;
this.threadSleep = threadSleep;
}

public void run() throws InterruptedException {
Expand Down Expand Up @@ -75,7 +78,7 @@ private void terminateClusters(List<String> clusters) {
LOGGER.info("Terminated {} clusters out of {}", endIndex, clusters.size());
// Sustained limit of 0.5 calls per second
// See https://docs.aws.amazon.com/general/latest/gr/emr.html
sleepForSustainedRatePerSecond(0.2);
sleepForSustainedRatePerSecond(0.2, threadSleep);
}
}

Expand Down Expand Up @@ -103,7 +106,7 @@ public static void main(String[] args) throws InterruptedException {
String instanceId = args[0];

try (EmrClient emrClient = EmrClient.create()) {
TerminateEMRClusters terminateClusters = new TerminateEMRClusters(emrClient, instanceId, StaticRateLimit.none());
TerminateEMRClusters terminateClusters = new TerminateEMRClusters(emrClient, instanceId, StaticRateLimit.none(), Thread::sleep);
terminateClusters.run();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import static com.github.tomakehurst.wiremock.stubbing.Scenario.STARTED;
import static sleeper.clients.testutil.ClientWiremockTestHelper.OPERATION_HEADER;
import static sleeper.clients.testutil.ClientWiremockTestHelper.wiremockCloudWatchClient;
import static sleeper.clients.testutil.ClientWiremockTestHelper.wiremockEcsClientV1;
import static sleeper.clients.testutil.ClientWiremockTestHelper.wiremockEcsClient;
import static sleeper.clients.testutil.ClientWiremockTestHelper.wiremockEmrClient;
import static sleeper.clients.testutil.ClientWiremockTestHelper.wiremockEmrServerlessClient;
import static sleeper.clients.testutil.WiremockCloudWatchTestHelper.anyRequestedForCloudWatchEvents;
Expand Down Expand Up @@ -85,6 +85,7 @@
import static sleeper.core.properties.instance.CdkDefinedInstanceProperty.TABLE_METRICS_RULE;
import static sleeper.core.properties.instance.CommonProperty.ID;
import static sleeper.core.properties.testutils.InstancePropertiesTestHelper.createTestInstanceProperties;
import static sleeper.core.util.ThreadSleepTestHelper.noWaits;

@WireMockTest
class ShutdownSystemProcessesIT {
Expand All @@ -94,8 +95,8 @@ class ShutdownSystemProcessesIT {

@BeforeEach
void setUp(WireMockRuntimeInfo runtimeInfo) {
shutdown = new ShutdownSystemProcesses(wiremockCloudWatchClient(runtimeInfo), wiremockEcsClientV1(runtimeInfo),
wiremockEmrClient(runtimeInfo), wiremockEmrServerlessClient(runtimeInfo), StaticRateLimit.none());
shutdown = new ShutdownSystemProcesses(wiremockCloudWatchClient(runtimeInfo), wiremockEcsClient(runtimeInfo),
wiremockEmrClient(runtimeInfo), wiremockEmrServerlessClient(runtimeInfo), StaticRateLimit.none(), noWaits());
}

private void shutdown() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,16 @@
*/
package sleeper.clients.testutil;

import com.amazonaws.services.ecs.AmazonECS;
import com.amazonaws.services.ecs.AmazonECSClientBuilder;
import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo;
import software.amazon.awssdk.services.cloudformation.CloudFormationClient;
import software.amazon.awssdk.services.cloudwatchevents.CloudWatchEventsClient;
import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient;
import software.amazon.awssdk.services.ecr.EcrClient;
import software.amazon.awssdk.services.ecs.EcsClient;
import software.amazon.awssdk.services.emr.EmrClient;
import software.amazon.awssdk.services.emrserverless.EmrServerlessClient;

import static sleeper.task.common.WiremockTestHelper.wiremockAwsV2Client;
import static sleeper.task.common.WiremockTestHelper.wiremockCredentialsProvider;
import static sleeper.task.common.WiremockTestHelper.wiremockEndpointConfiguration;

public class ClientWiremockTestHelper {

Expand All @@ -36,11 +33,8 @@ public class ClientWiremockTestHelper {
private ClientWiremockTestHelper() {
}

public static AmazonECS wiremockEcsClientV1(WireMockRuntimeInfo runtimeInfo) {
return AmazonECSClientBuilder.standard()
.withEndpointConfiguration(wiremockEndpointConfiguration(runtimeInfo))
.withCredentials(wiremockCredentialsProvider())
.build();
public static EcsClient wiremockEcsClient(WireMockRuntimeInfo runtimeInfo) {
return wiremockAwsV2Client(runtimeInfo, EcsClient.builder());
}

public static EcrClient wiremockEcrClient(WireMockRuntimeInfo runtimeInfo) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
import sleeper.core.statestore.StateStore;
import sleeper.core.statestore.StateStoreProvider;
import sleeper.core.statestore.testutils.FixedStateStoreProvider;
import sleeper.core.util.ExponentialBackoffWithJitter.Waiter;
import sleeper.core.util.ExponentialBackoffWithJitterTestHelper.WaitAction;
import sleeper.core.util.ThreadSleep;
import sleeper.core.util.ThreadSleepTestHelper;

import java.time.Duration;
import java.time.Instant;
Expand All @@ -65,8 +65,6 @@
import static sleeper.core.schema.SchemaTestHelper.schemaWithKey;
import static sleeper.core.statestore.AssignJobIdRequest.assignJobOnPartitionToFiles;
import static sleeper.core.statestore.testutils.StateStoreTestHelper.inMemoryStateStoreWithSinglePartition;
import static sleeper.core.util.ExponentialBackoffWithJitterTestHelper.recordWaits;
import static sleeper.core.util.ExponentialBackoffWithJitterTestHelper.withActionAfterWait;

public class CompactionTaskTestBase {
protected static final String DEFAULT_TABLE_ID = "test-table-id";
Expand All @@ -89,7 +87,7 @@ public class CompactionTaskTestBase {
protected final List<Duration> sleeps = new ArrayList<>();
protected final List<CompactionJobCommitRequest> commitRequestsOnQueue = new ArrayList<>();
protected final List<Duration> foundWaitsForFileAssignment = new ArrayList<>();
private Waiter waiterForFileAssignment = recordWaits(foundWaitsForFileAssignment);
private ThreadSleep waiterForFileAssignment = ThreadSleepTestHelper.recordWaits(foundWaitsForFileAssignment);

@BeforeEach
void setUpBase() {
Expand Down Expand Up @@ -231,8 +229,8 @@ protected void send(CompactionJob job) {
jobsOnQueue.add(job);
}

protected void actionAfterWaitForFileAssignment(WaitAction action) throws Exception {
waiterForFileAssignment = withActionAfterWait(waiterForFileAssignment, action);
protected void actionAfterWaitForFileAssignment(ThreadSleepTestHelper.WaitAction action) throws Exception {
waiterForFileAssignment = ThreadSleepTestHelper.withActionAfterWait(waiterForFileAssignment, action);
}

private MessageReceiver pollQueue() {
Expand Down
Loading
Loading