Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Buffer if custom content provider stream #5841

Merged
merged 2 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-ee5927f.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Buffer input data from ContentStreamProvider in cases where content length is known."
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,21 @@
@NotThreadSafe
public final class BufferingContentStreamProvider implements ContentStreamProvider {
private final ContentStreamProvider delegate;
private InputStream bufferedStream;
private final Long expectedLength;
private BufferStream bufferedStream;

private byte[] bufferedStreamData;
private int count;

public BufferingContentStreamProvider(ContentStreamProvider delegate) {
public BufferingContentStreamProvider(ContentStreamProvider delegate, Long expectedLength) {
this.delegate = delegate;
this.expectedLength = expectedLength;
}

@Override
public InputStream newStream() {
if (bufferedStreamData != null) {
return new ByteArrayInputStream(bufferedStreamData, 0, this.count);
return new ByteArrayStream(bufferedStreamData, 0, this.count);
}

if (bufferedStream == null) {
Expand All @@ -59,36 +61,57 @@ public InputStream newStream() {
return bufferedStream;
}

private class BufferStream extends BufferedInputStream {
// Replays request data that was previously buffered by the enclosing
// BufferingContentStreamProvider. Returned from newStream() once bufferedStreamData
// has been populated, so repeated stream creation does not re-read the delegate.
class ByteArrayStream extends ByteArrayInputStream {

// buf/offset/length delimit the valid region of the shared buffer; the buffer is
// not copied, so callers must not mutate it.
ByteArrayStream(byte[] buf, int offset, int length) {
super(buf, offset, length);
}

@Override
public void close() throws IOException {
super.close();
// Also close the outer buffering stream (and thereby its delegate).
// NOTE(review): assumes bufferedStream is non-null whenever a ByteArrayStream
// exists — i.e. buffered data is only ever produced by a BufferStream; confirm.
bufferedStream.close();
}
}

class BufferStream extends BufferedInputStream {
BufferStream(InputStream in) {
super(in);
}

@Override
public synchronized int read() throws IOException {
int read = super.read();
if (read < 0) {
saveBuffer();
}
return read;
public byte[] getBuf() {
return this.buf;
}

public int getCount() {
return this.count;
}

@Override
public synchronized int read(byte[] b, int off, int len) throws IOException {
int read = super.read(b, off, len);
if (read < 0) {
public void close() throws IOException {
// We only want to close the underlying stream if we're confident all its data is buffered. In some cases, the
// stream might be closed before we read everything, and we want to avoid closing in these cases if the request
// body is being reused.
if (!hasExpectedLength() || expectedLengthReached()) {
saveBuffer();
super.close();
}
dagnir marked this conversation as resolved.
Show resolved Hide resolved
return read;
}
}

private void saveBuffer() {
if (bufferedStreamData == null) {
IoUtils.closeQuietlyV2(in, null);
BufferingContentStreamProvider.this.bufferedStreamData = this.buf;
BufferingContentStreamProvider.this.count = this.count;
}
private void saveBuffer() {
if (bufferedStreamData == null) {
this.bufferedStreamData = bufferedStream.getBuf();
this.count = bufferedStream.getCount();
}
}

private boolean expectedLengthReached() {
return bufferedStream.getCount() >= expectedLength;
}

private boolean hasExpectedLength() {
return this.expectedLength != null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,13 @@ public static RequestBody fromFile(File file) {
public static RequestBody fromInputStream(InputStream inputStream, long contentLength) {
IoUtils.markStreamWithMaxReadLimit(inputStream);
InputStream nonCloseable = nonCloseableInputStream(inputStream);
return fromContentProvider(() -> {
ContentStreamProvider provider = () -> {
if (nonCloseable.markSupported()) {
invokeSafely(nonCloseable::reset);
}
return nonCloseable;
}, contentLength, Mimetype.MIMETYPE_OCTET_STREAM);
};
return new RequestBody(provider, contentLength, Mimetype.MIMETYPE_OCTET_STREAM);
}

/**
Expand Down Expand Up @@ -209,6 +210,14 @@ public static RequestBody empty() {

/**
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider}.
* <p>
Important: Be aware that this implementation requires buffering the contents of the {@code ContentStreamProvider}, which can
* cause increased memory usage.
* <p>
If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer to
* S3's documentation for
* <a href="https://docs.aws.amazon.com/AmazonS3/latest/API/s3_example_s3_Scenario_UploadStream_section.html">alternative
* methods</a>.
*
* @param provider The content provider.
* @param contentLength The content length.
Expand All @@ -217,17 +226,14 @@ public static RequestBody empty() {
* @return The created {@code RequestBody}.
*/
public static RequestBody fromContentProvider(ContentStreamProvider provider, long contentLength, String mimeType) {
dagnir marked this conversation as resolved.
Show resolved Hide resolved
return new RequestBody(provider, contentLength, mimeType);
return new RequestBody(new BufferingContentStreamProvider(provider, contentLength), contentLength, mimeType);
}

/**
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown. If you
* are able to provide the content length at creation time, consider using {@link #fromInputStream(InputStream, long)} or
* {@link #fromContentProvider(ContentStreamProvider, long, String)} to negate the need to read through the stream to find
* the content length.
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown.
* <p>
* Important: Be aware that this override requires the SDK to buffer the entirety of your content stream to compute the
* content length. This will cause increased memory usage.
Important: Be aware that this implementation requires buffering the contents of the {@code ContentStreamProvider}, which can
* cause increased memory usage.
* <p>
If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer to
* S3's documentation for
Expand All @@ -240,7 +246,7 @@ public static RequestBody fromContentProvider(ContentStreamProvider provider, lo
* @return The created {@code RequestBody}.
*/
public static RequestBody fromContentProvider(ContentStreamProvider provider, String mimeType) {
return new RequestBody(new BufferingContentStreamProvider(provider), null, mimeType);
return new RequestBody(new BufferingContentStreamProvider(provider, null), null, mimeType);
}

/**
Expand All @@ -254,7 +260,7 @@ private static RequestBody fromBytesDirect(byte[] bytes) {
* Creates a {@link RequestBody} using the specified bytes (without copying).
*/
private static RequestBody fromBytesDirect(byte[] bytes, String mimetype) {
return fromContentProvider(() -> new ByteArrayInputStream(bytes), bytes.length, mimetype);
return new RequestBody(() -> new ByteArrayInputStream(bytes), (long) bytes.length, mimetype);
}

private static InputStream nonCloseableInputStream(InputStream inputStream) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
package software.amazon.awssdk.core.internal.sync;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Random;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
Expand Down Expand Up @@ -110,17 +109,89 @@ void newStream_closeClosesDelegateStream() throws IOException {
}

@Test
void newStream_allDataBuffered_closesDelegateStream() throws IOException {
// Closing the buffering stream after fully draining it must propagate close()
// to the delegate stream exactly once (verified via the Mockito spy).
public void newStream_delegateStreamClosedOnBufferingStreamClose() throws IOException {
InputStream delegateStream = Mockito.spy(new ByteArrayInputStream(TEST_DATA));

// No content length supplied: provider buffers and closes on stream close.
requestBody = RequestBody.fromContentProvider(() -> delegateStream, "text/plain");

InputStream stream = requestBody.contentStreamProvider().newStream();
IoUtils.drainInputStream(stream);
stream.close();

Mockito.verify(delegateStream).close();
}

@Test
// With a known content length: reading exactly contentLength bytes and then
// closing should mark the buffer complete, so the next newStream() call serves
// the buffered copy (a ByteArrayStream) instead of re-reading the delegate.
public void newStream_lengthKnown_readUpToLengthThenClosed_newStreamUsesBufferedData() throws IOException {
ByteArrayInputStream stream = new ByteArrayInputStream(TEST_DATA);
requestBody = RequestBody.fromContentProvider(() -> stream, TEST_DATA.length, "text/plain");

int totalRead = 0;
int read;

// Drain one byte at a time, counting bytes until EOF.
InputStream stream1 = requestBody.contentStreamProvider().newStream();
do {
read = stream1.read();
if (read != -1) {
++totalRead;
}
} while (read != -1);

assertThat(totalRead).isEqualTo(TEST_DATA.length);

stream1.close();

// Subsequent stream must come from the buffered data, not the delegate.
assertThat(requestBody.contentStreamProvider().newStream())
.isInstanceOf(BufferingContentStreamProvider.ByteArrayStream.class);
}

@Test
// With a known content length: closing after only a partial read must NOT freeze
// the (incomplete) buffer. A later newStream() should still be a BufferStream and
// must yield the full, correct content (compared via CRC32).
public void newStream_lengthKnown_partialRead_close_doesNotBufferData() throws IOException {
// We need a large buffer because BufferedInputStream buffers data in chunks. If the buffer is small enough, a single
// read() on the BufferedInputStream might actually buffer all the delegate's data.

byte[] newData = new byte[16536];
new Random().nextBytes(newData);
ByteArrayInputStream stream = new ByteArrayInputStream(newData);
requestBody = RequestBody.fromContentProvider(() -> stream, newData.length, "text/plain");

// Read a single byte, then close before reaching the expected length.
InputStream stream1 = requestBody.contentStreamProvider().newStream();
int read = stream1.read();
assertThat(read).isNotEqualTo(-1);

stream1.close();

InputStream stream2 = requestBody.contentStreamProvider().newStream();
assertThat(stream2).isInstanceOf(BufferingContentStreamProvider.BufferStream.class);

// The second pass must still produce byte-identical content.
assertThat(getCrc32(stream2)).isEqualTo(getCrc32(new ByteArrayInputStream(newData)));
}

@Test
public void newStream_bufferedDataStreamPartialRead_closed_bufferedDataIsNotReplaced() throws IOException {
byte[] newData = new byte[16536];
new Random().nextBytes(newData);
String newDataChecksum = getCrc32(new ByteArrayInputStream(newData));

ByteArrayInputStream stream = new ByteArrayInputStream(newData);

requestBody = RequestBody.fromContentProvider(() -> stream, "text/plain");
InputStream stream1 = requestBody.contentStreamProvider().newStream();
IoUtils.drainInputStream(stream1);
stream1.close();

InputStream stream2 = requestBody.contentStreamProvider().newStream();
assertThat(stream2).isInstanceOf(BufferingContentStreamProvider.ByteArrayStream.class);

int read = stream2.read();
assertThat(read).isNotEqualTo(-1);

stream2.close();

InputStream stream3 = requestBody.contentStreamProvider().newStream();
assertThat(stream3).isInstanceOf(BufferingContentStreamProvider.ByteArrayStream.class);

IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
Mockito.verifyNoMoreInteractions(delegateStream);
assertThat(getCrc32(stream3)).isEqualTo(newDataChecksum);
}

private static String getCrc32(InputStream inputStream) {
Expand Down
Loading