Skip to content

Commit

Permalink
delegates cross region bucket location determination to SDK (#574)
Browse files Browse the repository at this point in the history
  • Loading branch information
markjschreiber authored Dec 2, 2024
1 parent 558f42c commit b1b20e3
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 220 deletions.
154 changes: 5 additions & 149 deletions src/main/java/software/amazon/nio/spi/s3/S3ClientProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,22 @@

package software.amazon.nio.spi.s3;

import static java.util.concurrent.TimeUnit.MINUTES;
import static software.amazon.nio.spi.s3.util.TimeOutUtils.logAndGenerateExceptionOnTimeOut;

import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import java.net.URI;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3CrtAsyncClientBuilder;
import software.amazon.awssdk.services.s3.model.HeadBucketResponse;
import software.amazon.awssdk.services.s3.model.S3Exception;
import software.amazon.nio.spi.s3.config.S3NioSpiConfiguration;

/**
* Factory/builder class that creates async S3 clients. It also provides
* default clients that can be used for basic operations (e.g. bucket discovery).
* Creates async S3 clients used by this library.
*/
public class S3ClientProvider {

private static final Logger logger = LoggerFactory.getLogger(S3ClientProvider.class);

/**
* Default asynchronous client using the "<a href="https://s3.us-east-1.amazonaws.com">...</a>" endpoint
*/
@Deprecated
protected S3AsyncClient universalClient;

/**
Expand All @@ -49,10 +35,6 @@ public class S3ClientProvider {
S3AsyncClient.crtBuilder()
.crossRegionAccessEnabled(true);

private final Cache<String, String> bucketRegionCache = Caffeine.newBuilder()
.maximumSize(16)
.expireAfterWrite(Duration.ofMinutes(30))
.build();

private final Cache<String, CacheableS3Client> bucketClientCache = Caffeine.newBuilder()
.maximumSize(4)
Expand All @@ -61,35 +43,12 @@ public class S3ClientProvider {

public S3ClientProvider(S3NioSpiConfiguration c) {
this.configuration = (c == null) ? new S3NioSpiConfiguration() : c;
this.universalClient = S3AsyncClient.builder()
.endpointOverride(URI.create("https://s3.us-east-1.amazonaws.com"))
.crossRegionAccessEnabled(true)
.region(Region.US_EAST_1)
.build();
}

public void asyncClientBuilder(final S3CrtAsyncClientBuilder builder) {
asyncClientBuilder = builder;
}

/**
* This method returns a universal client bound to the us-east-1 region
* that can be used by certain S3 operations for discovery such as getBucketLocation.
*
* @return an S3AsyncClient bound to us-east-1
*/
S3AsyncClient universalClient() {
return universalClient;
}

/**
* Sets the fallback client used to make {@code S3AsyncClient#getBucketLocation()} calls. Typically, this would
* only be set for testing purposes to use a {@code Mock} or {@code Spy} class.
* @param client the client to be used for getBucketLocation calls.
*/
void universalClient(S3AsyncClient client) {
this.universalClient = client;
}

/**
* Generates a sync client for the named bucket using a client configured by the default region configuration chain.
Expand All @@ -98,108 +57,15 @@ void universalClient(S3AsyncClient client) {
* @return an S3 client appropriate for the region of the named bucket
*/
protected S3AsyncClient generateClient(String bucket) {
try {
return generateClient(bucket, S3AsyncClient.create());
} catch (ExecutionException e) {
throw new RuntimeException(e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
}

/**
* Generate an async client for the named bucket using a provided client to
* determine the location of the named client
*
* @param bucketName the name of the bucket to make the client for
* @param locationClient the client used to determine the location of the
* named bucket, recommend using {@code S3ClientProvider#UNIVERSAL_CLIENT}
* @return an S3 client appropriate for the region of the named bucket
*/
S3AsyncClient generateClient(String bucketName, S3AsyncClient locationClient)
throws ExecutionException, InterruptedException {
logger.debug("generating client for bucket: '{}'", bucketName);

String bucketLocation = null;
if (configuration.endpointUri() == null) {
// we try to locate a bucket only if no endpoint is provided, which means we are dealing with AWS S3 buckets
bucketLocation = getBucketLocation(bucketName, locationClient);

if (bucketLocation == null) {
// if here, no S3 nor other client has been created yet, and we do not
// have a location; we'll let it figure out from the profile region
logger.warn("Unable to determine the region of bucket: '{}'. Generating a client for the profile region.",
bucketName);
}
}

var client = bucketClientCache.getIfPresent(bucketName);
var client = bucketClientCache.getIfPresent(bucket);
if (client != null && !client.isClosed()) {
return client;
} else {
if (client != null && client.isClosed()) {
bucketClientCache.invalidate(bucketName); // remove the closed client from the cache
}
String r = Optional.ofNullable(bucketLocation).orElse(configuration.getRegion());
return bucketClientCache.get(bucketName, b -> new CacheableS3Client(configureCrtClientForRegion(r)));
}
}

private String getBucketLocation(String bucketName, S3AsyncClient locationClient)
throws ExecutionException, InterruptedException {

if (bucketRegionCache.getIfPresent(bucketName) != null) {
return bucketRegionCache.getIfPresent(bucketName);
}

logger.debug("checking if the bucket is in the same region as the providedClient using HeadBucket");
try (var client = locationClient) {
final HeadBucketResponse response = client
.headBucket(builder -> builder.bucket(bucketName))
.get(configuration.getTimeoutLow(), MINUTES);
bucketRegionCache.put(bucketName, response.bucketRegion());
return response.bucketRegion();

} catch (TimeoutException e) {
throw logAndGenerateExceptionOnTimeOut(
logger,
"generateClient",
configuration.getTimeoutLow(),
MINUTES);
} catch (Throwable t) {

if (t instanceof ExecutionException &&
t.getCause() instanceof S3Exception &&
((S3Exception) t.getCause()).statusCode() == 301) { // you got a redirect, the region should be in the header
logger.debug("HeadBucket was unsuccessful, redirect received, attempting to extract x-amz-bucket-region header");
S3Exception s3e = (S3Exception) t.getCause();
final var matchingHeaders = s3e.awsErrorDetails().sdkHttpResponse().matchingHeaders("x-amz-bucket-region");
if (matchingHeaders != null && !matchingHeaders.isEmpty()) {
bucketRegionCache.put(bucketName, matchingHeaders.get(0));
return matchingHeaders.get(0);
}
} else if (t instanceof ExecutionException &&
t.getCause() instanceof S3Exception &&
((S3Exception) t.getCause()).statusCode() == 403) { // HeadBucket was forbidden
logger.debug("HeadBucket forbidden. Attempting a call to GetBucketLocation using the UNIVERSAL_CLIENT");
try {
String location = universalClient.getBucketLocation(builder -> builder.bucket(bucketName))
.get(configuration.getTimeoutLow(), MINUTES).locationConstraintAsString();
bucketRegionCache.put(bucketName, location);
} catch (TimeoutException e) {
throw logAndGenerateExceptionOnTimeOut(
logger,
"generateClient",
configuration.getTimeoutLow(),
MINUTES);
}
} else {
// didn't handle the exception - rethrow it
throw t;
bucketClientCache.invalidate(bucket); // remove the closed client from the cache
}
return bucketClientCache.get(bucket, b -> new CacheableS3Client(configureCrtClient().build()));
}
return "";
}

S3CrtAsyncClientBuilder configureCrtClient() {
Expand All @@ -216,14 +82,4 @@ S3CrtAsyncClientBuilder configureCrtClient() {
return asyncClientBuilder.forcePathStyle(configuration.getForcePathStyle());
}

private S3AsyncClient configureCrtClientForRegion(String regionName) {
var region = getRegionFromRegionName(regionName);
logger.debug("bucket region is: '{}'", region);
return configureCrtClient().region(region).build();
}

private static Region getRegionFromRegionName(String regionName) {
return (regionName == null || regionName.isBlank()) ? null : Region.of(regionName);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ public FixedS3ClientProvider(S3AsyncClient client) {
this.client = client;
}

@Override
S3AsyncClient universalClient() {
return client;
}

@Override
protected S3AsyncClient generateClient(String bucketName) {
return client;
Expand Down
76 changes: 10 additions & 66 deletions src/test/java/software/amazon/nio/spi/s3/S3ClientProviderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,21 @@
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static software.amazon.nio.spi.s3.S3Matchers.anyConsumer;

import java.net.URI;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.GetBucketLocationResponse;
import software.amazon.awssdk.services.s3.model.HeadBucketResponse;
import software.amazon.awssdk.services.s3.model.S3Exception;
import software.amazon.nio.spi.s3.config.S3NioSpiConfiguration;

@ExtendWith(MockitoExtension.class)
public class S3ClientProviderTest {

@Mock
S3AsyncClient mockClient; //client used to determine bucket location

S3ClientProvider provider;

@BeforeEach
Expand All @@ -50,87 +37,46 @@ public void initialization() {

assertNotNull(s3ClientProvider.configuration);

S3AsyncClient t = s3ClientProvider.universalClient();
assertNotNull(t);

var config = new S3NioSpiConfiguration();
assertSame(config, new S3ClientProvider(config).configuration);
}

@Test
public void testGenerateAsyncClientWithNoErrors() throws ExecutionException, InterruptedException {
when(mockClient.headBucket(anyConsumer()))
.thenReturn(CompletableFuture.completedFuture(
HeadBucketResponse.builder().bucketRegion("us-west-2").build()));
final var s3Client = provider.generateClient("test-bucket", mockClient);
public void testGenerateAsyncClientWithNoErrors() {
final var s3Client = provider.generateClient("test-bucket");
assertNotNull(s3Client);
}

@Test
public void testGenerateClientIsCacheableClass() throws Exception {
when(mockClient.headBucket(anyConsumer()))
.thenReturn(CompletableFuture.completedFuture(
HeadBucketResponse.builder().bucketRegion("us-west-2").build()));
final var s3Client = provider.generateClient("test-bucket", mockClient);
public void testGenerateClientIsCacheableClass() {
final var s3Client = provider.generateClient("test-bucket");
assertInstanceOf(CacheableS3Client.class, s3Client);
}

@Test
public void testGenerateClientCachesClients() throws Exception {
when(mockClient.headBucket(anyConsumer()))
.thenReturn(CompletableFuture.completedFuture(
HeadBucketResponse.builder().bucketRegion("us-west-2").build()));
final var s3Client = provider.generateClient("test-bucket", mockClient);
final var s3Client2 = provider.generateClient("test-bucket", mockClient);
public void testGenerateClientCachesClients() {
final var s3Client = provider.generateClient("test-bucket");
final var s3Client2 = provider.generateClient("test-bucket");
assertSame(s3Client, s3Client2);
}

@Test
public void testClosedClientIsNotReused() throws ExecutionException, InterruptedException {
when(mockClient.headBucket(anyConsumer()))
.thenReturn(CompletableFuture.completedFuture(
HeadBucketResponse.builder().bucketRegion("us-west-2").build()));
public void testClosedClientIsNotReused() {

final var s3Client = provider.generateClient("test-bucket", mockClient);
final var s3Client = provider.generateClient("test-bucket");
assertNotNull(s3Client);

// now close the client
s3Client.close();

// now generate a new client with the same bucket name
final var s3Client2 = provider.generateClient("test-bucket", mockClient);
final var s3Client2 = provider.generateClient("test-bucket");
assertNotNull(s3Client2);

// assert it is not the closed client
assertNotSame(s3Client, s3Client2);
}

@Test
public void testGenerateAsyncClientWith403Response() throws ExecutionException, InterruptedException {
// when you get a forbidden response from HeadBucket
when(mockClient.headBucket(anyConsumer())).thenReturn(
CompletableFuture.failedFuture(S3Exception.builder().statusCode(403).build())
);

// you should fall back to a get bucket location attempt from the universal client
var mockUniversalClient = mock(S3AsyncClient.class);
provider.universalClient(mockUniversalClient);
when(mockUniversalClient.getBucketLocation(anyConsumer())).thenReturn(CompletableFuture.completedFuture(
GetBucketLocationResponse.builder()
.locationConstraint("us-west-2")
.build()
));

// which should get you a client
final var s3Client = provider.generateClient("test-bucket", mockClient);
assertNotNull(s3Client);

final var inOrder = inOrder(mockClient, mockUniversalClient);
inOrder.verify(mockClient).headBucket(anyConsumer());
inOrder.verify(mockUniversalClient).getBucketLocation(anyConsumer());
inOrder.verifyNoMoreInteractions();
}

@Test
public void generateAsyncClientByEndpointBucketCredentials() {
// GIVEN
Expand All @@ -144,7 +90,6 @@ public void generateAsyncClientByEndpointBucketCredentials() {

// THEN
verify(BUILDER, times(1)).endpointOverride(URI.create("https://endpoint1:1010"));
verify(BUILDER, times(1)).region(null);

// GIVEN
BUILDER = spy(S3AsyncClient.crtBuilder());
Expand All @@ -156,6 +101,5 @@ public void generateAsyncClientByEndpointBucketCredentials() {

// THEN
verify(BUILDER, times(1)).endpointOverride(URI.create("https://endpoint2:2020"));
verify(BUILDER, times(1)).region(null);
}
}

0 comments on commit b1b20e3

Please sign in to comment.