Skip to content

Commit

Permalink
Added MultiPart Upload to Transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
durner committed Jul 3, 2023
1 parent 5c35ea8 commit 9fc9aed
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 20 deletions.
4 changes: 3 additions & 1 deletion include/cloud/aws.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class AWS : public Provider {
Settings _settings;
/// The secret
std::unique_ptr<Secret> _secret;
/// The multipart upload size
uint64_t _multipartUploadSize = 128ull << 20;

public:
/// Get instance details
Expand Down Expand Up @@ -104,7 +106,7 @@ class AWS : public Provider {
/// Get the settings
[[nodiscard]] inline Settings getSettings() { return _settings; }
/// Allows multipart upload if size > 0
/// Returns the part size in bytes stored in _multipartUploadSize, so the split
/// size can be configured instead of being hard-coded in this accessor.
[[nodiscard]] uint64_t multipartUploadSize() const override { return _multipartUploadSize; }

/// Builds the http request for downloading a blob or listing the directory
[[nodiscard]] std::unique_ptr<utils::DataVector<uint8_t>> getRequest(const std::string& filePath, const std::pair<uint64_t, uint64_t>& range) const override;
Expand Down
2 changes: 2 additions & 0 deletions include/cloud/minio.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class MinIO : public AWS {
[[nodiscard]] std::string getAddress() const override;
/// Get the instance details
[[nodiscard]] Provider::Instance getInstanceDetails(network::TaskedSendReceiver& sendReceiver) override;
/// Set the upload split size
constexpr void setMultipartUploadSize(uint64_t size) { _multipartUploadSize = size; }

friend Provider;
};
Expand Down
114 changes: 113 additions & 1 deletion include/network/transaction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "cloud/provider.hpp"
#include "network/message_result.hpp"
#include "network/original_message.hpp"
#include <atomic>
#include <cassert>
#include <memory>
#include <span>
Expand Down Expand Up @@ -32,16 +33,42 @@ class Transaction {
class Iterator;
class ConstIterator;

protected:
/// Bookkeeping for a single multipart upload: the upload id, the ETag of every
/// part and a counter of part requests that have not completed yet.
struct MultipartUpload {
    /// The uploadId
    std::string uploadId;
    /// The eTags (one entry per part)
    std::vector<std::string> eTags;
    /// The number of outstanding part requests
    std::atomic<int> outstanding;

    /// The constructor: reserves one eTag slot per part and marks all parts outstanding
    explicit MultipartUpload(int parts) : eTags(parts), outstanding(parts) {}
    /// Copy constructor (not copyable)
    MultipartUpload(const MultipartUpload& other) = delete;
    /// Move constructor; std::atomic itself cannot be moved, so the current
    /// counter value is loaded from `other` and used to initialize the new counter
    MultipartUpload(MultipartUpload&& other) noexcept : uploadId(std::move(other.uploadId)), eTags(std::move(other.eTags)), outstanding(other.outstanding.load()) {}
    /// Copy assignment (not assignable)
    MultipartUpload& operator=(const MultipartUpload& other) = delete;
    /// Move assignment (not assignable)
    MultipartUpload& operator=(MultipartUpload&& other) = delete;
};

protected:
/// The provider
const cloud::Provider* provider;
/// The message typedef
using message_vector_type = std::vector<std::unique_ptr<network::OriginalMessage>>;
/// The message
message_vector_type messages;
/// Multipart uploads
std::vector<MultipartUpload> multipartUploads;

public:
/// Helper function to build callback messages: perfectly forwards the callback
/// and all remaining arguments to the OriginalCallbackMessage<Callback> constructor
/// and returns the owning pointer to the new message
template <typename Callback, typename... Arguments>
std::unique_ptr<OriginalCallbackMessage<Callback>> makeCallbackMessage(Callback&& c, Arguments&&... args) {
    using MessageType = network::OriginalCallbackMessage<Callback>;
    return std::make_unique<MessageType>(std::forward<Callback>(c), std::forward<Arguments>(args)...);
}

public:
/// The constructor
Transaction() = default;
/// The explicit constructor with the provider
Expand Down Expand Up @@ -74,6 +101,8 @@ class Transaction {
/// Build a new put request for synchronous calls
inline void putObjectRequest(const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(provider);
if (provider->multipartUploadSize() && size > provider->multipartUploadSize())
return putObjectRequestMultiPart(remotePath, data, size, result, capacity, traceId);
auto object = std::string_view(data, size);
auto originalMsg = std::make_unique<network::OriginalMessage>(provider->putRequest(remotePath, object), provider->getAddress(), provider->getPort(), result, capacity, traceId);
originalMsg->setPutRequestData(reinterpret_cast<const uint8_t*>(data), size);
Expand All @@ -84,6 +113,8 @@ class Transaction {
template <typename Callback>
inline void putObjectRequest(Callback&& callback, const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(provider);
if (provider->multipartUploadSize() && size > provider->multipartUploadSize())
return putObjectRequestMultiPart(std::forward<Callback>(callback), remotePath, data, size, result, capacity, traceId);
auto object = std::string_view(data, size);
auto originalMsg = std::make_unique<network::OriginalCallbackMessage<Callback>>(std::forward<Callback&&>(callback), provider->putRequest(remotePath, object), provider->getAddress(), provider->getPort(), result, capacity, traceId);
originalMsg->setPutRequestData(reinterpret_cast<const uint8_t*>(data), size);
Expand All @@ -105,6 +136,87 @@ class Transaction {
messages.push_back(std::move(originalMsg));
}

private:
/// Build a new put request for synchronous calls
inline void putObjectRequestMultiPart(const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(provider);
auto splitSize = provider->multipartUploadSize();
auto parts = (size / splitSize) + ((size % splitSize) ? 1u : 0u);
multipartUploads.emplace_back(parts);
auto position = multipartUploads.size() - 1;

auto uploadMessages = [position, parts, data, remotePath, traceId, splitSize, size, this](network::MessageResult& result) {
if (!result.success())
return;

multipartUploads[position].uploadId = provider->getUploadId(result.getResult());
auto offset = 0ull;
for (auto i = 1ull; i <= parts; i++) {
auto finishMultipart = [position, remotePath, traceId, i, this](network::MessageResult& result) {
// TODO: requires abort handling
if (!result.success())
return;

multipartUploads[position].eTags[i - 1] = provider->getETag(std::string_view(reinterpret_cast<const char*>(result.getData()), result.getOffset()));
if (multipartUploads[position].outstanding.fetch_sub(1) == 1) {
auto originalMsg = std::make_unique<network::OriginalMessage>(provider->completeMultiPartRequest(remotePath, multipartUploads[position].uploadId, multipartUploads[position].eTags), provider->getAddress(), provider->getPort(), nullptr, 0, traceId);
messages.push_back(std::move(originalMsg));
}
};
auto partSize = (i != parts) ? splitSize : size - offset;
auto object = std::string_view(data + offset, partSize);
auto originalMsg = makeCallbackMessage(std::move(finishMultipart), provider->putRequestGeneric(remotePath, object, i, multipartUploads[position].uploadId), provider->getAddress(), provider->getPort(), nullptr, 0, traceId);
originalMsg->setPutRequestData(reinterpret_cast<const uint8_t*>(data + offset), partSize);
messages.push_back(std::move(originalMsg));
offset += partSize;
}
};

auto originalMsg = makeCallbackMessage(std::move(uploadMessages), provider->createMultiPartRequest(remotePath), provider->getAddress(), provider->getPort(), result, capacity, traceId);
messages.push_back(std::move(originalMsg));
}

/// Build a new put request with callback
template <typename Callback>
inline void putObjectRequestMultiPart(Callback&& callback, const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(provider);
auto splitSize = provider->multipartUploadSize();
auto parts = (size / splitSize) + (size % splitSize) ? 1u : 0u;
multipartUploads.emplace_back(parts);
auto position = multipartUploads.size() - 1;

auto uploadMessages = [&](network::MessageResult& result) {
if (!result.success())
return;

multipartUploads[position].uploadId = provider->getUploadId(result.getResult());
auto offset = 0ull;
for (auto i = 1ull; i <= parts; i++) {
auto finishMultipart = [&](network::MessageResult& result) {
// TODO: requires abort handling
if (!result.success())
return;

multipartUploads[position].eTags[i - 1] = provider->getETag(std::string_view(reinterpret_cast<const char*>(result.getData()), result.getOffset()));
if (multipartUploads[position].outstanding.fetch_sub(1) == 1) {
auto originalMsg = std::make_unique<network::OriginalCallbackMessage<Callback>>(std::forward<Callback&&>(callback), provider->completeMultiPartRequest(remotePath, multipartUploads[position].uploadId, multipartUploads[position].eTags), provider->getAddress(), provider->getPort(), result, capacity, traceId);
messages.push_back(std::move(originalMsg));
}
};
auto partSize = (i == parts) ? splitSize : size - offset;
auto object = std::string_view(data + offset, partSize);
auto originalMsg = makeCallbackMessage(std::move(finishMultipart), provider->putRequestGeneric(remotePath, object, i, multipartUploads[position].uploadId), provider->getAddress(), provider->getPort(), nullptr, 0, traceId);
originalMsg->setPutRequestData(reinterpret_cast<const uint8_t*>(data + offset), partSize);
messages.push_back(std::move(originalMsg));
offset += partSize;
}
};

auto originalMsg = makeCallbackMessage(std::move(uploadMessages), provider->createMultiPartRequest(remotePath), provider->getAddress(), provider->getPort(), result, capacity, traceId);
messages.push_back(std::move(originalMsg));
}

public:
/// The iterator
using iterator = Iterator;
/// The const iterator
Expand Down
9 changes: 4 additions & 5 deletions src/cloud/aws.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ unique_ptr<utils::DataVector<uint8_t>> AWS::putRequestGeneric(const string& file

// Is it a multipart upload?
if (part) {
request.path += "?partNumber=" + to_string(part) + "&uploadId=";
request.path += uploadId;
request.queries.emplace("partNumber", to_string(part));
request.queries.emplace("uploadId", uploadId);
}

request.bodyLength = object.size();
Expand Down Expand Up @@ -320,7 +320,7 @@ unique_ptr<utils::DataVector<uint8_t>> AWS::createMultiPartRequest(const string&
request.path = "/" + filePath;
else
request.path = "/" + _settings.bucket + "/" + filePath;
request.path += "?uploads";
request.queries.emplace("uploads", "");
request.bodyData = nullptr;
request.bodyLength = 0;
request.headers.emplace("Host", getAddress());
Expand Down Expand Up @@ -365,9 +365,8 @@ unique_ptr<utils::DataVector<uint8_t>> AWS::completeMultiPartRequest(const strin
request.path = "/" + filePath;
else
request.path = "/" + _settings.bucket + "/" + filePath;
request.path += "&uploadId=";
request.path += uploadId;

request.queries.emplace("uploadId", uploadId);
request.bodyData = nullptr;
request.bodyLength = 0;
request.headers.emplace("Host", getAddress());
Expand Down
9 changes: 4 additions & 5 deletions src/cloud/gcp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ unique_ptr<utils::DataVector<uint8_t>> GCP::putRequestGeneric(const string& file

// Is it a multipart upload?
if (part) {
request.path += "?partNumber=" + to_string(part) + "&uploadId=";
request.path += uploadId;
request.queries.emplace("partNumber", to_string(part));
request.queries.emplace("uploadId", uploadId);
}

request.bodyData = reinterpret_cast<const uint8_t*>(object.data());
Expand Down Expand Up @@ -171,7 +171,7 @@ unique_ptr<utils::DataVector<uint8_t>> GCP::createMultiPartRequest(const string&
request.method = "POST";
request.type = "HTTP/1.1";
request.path = "/" + filePath;
request.path += "?uploads";
request.queries.emplace("uploads", "");
request.bodyData = nullptr;
request.bodyLength = 0;

Expand Down Expand Up @@ -208,8 +208,7 @@ unique_ptr<utils::DataVector<uint8_t>> GCP::completeMultiPartRequest(const strin
request.method = "POST";
request.type = "HTTP/1.1";
request.path = "/" + filePath;
request.path += "&uploadId=";
request.path += uploadId;
request.queries.emplace("uploadId", uploadId);
request.bodyData = nullptr;
request.bodyLength = 0;

Expand Down
15 changes: 9 additions & 6 deletions src/network/transaction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ using namespace std;
void Transaction::processSync(TaskedSendReceiver& sendReceiver)
// Processes the request messages
{
    // Index of the first message that has not been handed to the send receiver yet.
    // An index is used instead of a range-for because the loop below re-checks
    // messages.size() after every processing round and sends whatever was appended.
    auto messageCounter = 0ull;
    do {
        // send the original request messages that have not been sent so far
        for (; messageCounter < messages.size(); messageCounter++) {
            sendReceiver.sendSync(messages[messageCounter].get());
        }

        // do the download work
        sendReceiver.processSync();
        // repeat while processing has appended further messages
    } while (messages.size() != messageCounter);
}
//---------------------------------------------------------------------------
void Transaction::processAsync(TaskedSendReceiver& sendReceiver)
Expand Down
21 changes: 19 additions & 2 deletions test/integration/minio.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#include "catch2/single_include/catch2/catch.hpp"
#include "cloud/provider.hpp"
#include "cloud/minio.hpp"
#include "network/tasked_send_receiver.hpp"
#include "network/transaction.hpp"
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
//---------------------------------------------------------------------------
// AnyBlob - Universal Cloud Object Storage Library
// Dominik Durner, 2022
Expand Down Expand Up @@ -71,6 +71,23 @@ TEST_CASE("MinIO Integration") {
REQUIRE(it.success());
}
}
{
// Create the multipart put request
auto minio = static_cast<anyblob::cloud::MinIO*>(provider.get());
minio->setMultipartUploadSize(6ull << 20);
anyblob::network::Transaction putTxn(provider.get());
for (auto i = 0u; i < 2; i++)
putTxn.putObjectRequest(fileName[i], content[i].data(), content[i].size());

// Upload the request synchronously with the scheduler object on this thread
putTxn.processSync(sendReceiver);

// Check the upload
for (const auto& it : putTxn) {
// Successful request
REQUIRE(it.success());
}
}
{
// Create the get request
anyblob::network::Transaction getTxn(provider.get());
Expand Down Expand Up @@ -102,7 +119,7 @@ TEST_CASE("MinIO Integration") {
}
}
{
// Create the put request
// Create the delete request
anyblob::network::Transaction deleteTxn(provider.get());
for (auto i = 0u; i < 2; i++)
deleteTxn.deleteObjectRequest(fileName[i]);
Expand Down

0 comments on commit 9fc9aed

Please sign in to comment.