Skip to content

Commit

Permalink
Added MultiPart Upload to Transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
durner committed Jul 3, 2023
1 parent 5c35ea8 commit b705e36
Show file tree
Hide file tree
Showing 8 changed files with 343 additions and 30 deletions.
4 changes: 3 additions & 1 deletion include/cloud/aws.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class AWS : public Provider {
Settings _settings;
/// The secret
std::unique_ptr<Secret> _secret;
/// The multipart upload size
uint64_t _multipartUploadSize = 128ull << 20;

public:
/// Get instance details
Expand Down Expand Up @@ -104,7 +106,7 @@ class AWS : public Provider {
/// Get the settings
[[nodiscard]] inline Settings getSettings() { return _settings; }
/// Allows multipart upload if size > 0
[[nodiscard]] uint64_t multipartUploadSize() const override { return 128ull << 20; }
[[nodiscard]] uint64_t multipartUploadSize() const override { return _multipartUploadSize; }

/// Builds the http request for downloading a blob or listing the directory
[[nodiscard]] std::unique_ptr<utils::DataVector<uint8_t>> getRequest(const std::string& filePath, const std::pair<uint64_t, uint64_t>& range) const override;
Expand Down
2 changes: 2 additions & 0 deletions include/cloud/minio.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class MinIO : public AWS {
[[nodiscard]] std::string getAddress() const override;
/// Get the instance details
[[nodiscard]] Provider::Instance getInstanceDetails(network::TaskedSendReceiver& sendReceiver) override;
/// Set the upload split size
constexpr void setMultipartUploadSize(uint64_t size) { _multipartUploadSize = size; }

friend Provider;
};
Expand Down
126 changes: 120 additions & 6 deletions include/network/transaction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "cloud/provider.hpp"
#include "network/message_result.hpp"
#include "network/original_message.hpp"
#include <atomic>
#include <cassert>
#include <memory>
#include <span>
Expand Down Expand Up @@ -32,20 +33,48 @@ class Transaction {
class Iterator;
class ConstIterator;

protected:
/// Bookkeeping for one multipart upload: the upload id, the per-part eTags and the
/// number of part requests that have not finished yet.
/// The type is move-constructible only (so it can live in a std::vector); copying and
/// any form of assignment are disabled because of the std::atomic member.
struct MultipartUpload {
    /// The uploadId
    std::string uploadId;
    /// The eTags, one slot per part
    std::vector<std::string> eTags;
    /// The number of outstanding part requests
    std::atomic<int> outstanding;

    /// The constructor, reserves one eTag slot per part and marks all parts as outstanding
    explicit MultipartUpload(int parts) : eTags(parts), outstanding(parts) {}
    /// Copy constructor (deleted)
    MultipartUpload(const MultipartUpload& other) = delete;
    /// Move constructor
    /// NOTE(review): the counter is transferred with a plain load, so moving an element
    /// while its part requests are being completed concurrently is not safe - confirm
    /// that the owning vector never reallocates during an active upload
    MultipartUpload(MultipartUpload&& other) noexcept : uploadId(std::move(other.uploadId)), eTags(std::move(other.eTags)), outstanding(other.outstanding.load()) {}
    /// Copy assignment (deleted)
    MultipartUpload& operator=(const MultipartUpload& other) = delete;
    /// Move assignment (deleted)
    MultipartUpload& operator=(MultipartUpload&& other) = delete;
};

protected:
/// The provider
const cloud::Provider* provider;
/// The message typedef
using message_vector_type = std::vector<std::unique_ptr<network::OriginalMessage>>;
/// The message
message_vector_type messages;
/// The message send counter
uint64_t messageCounter;
/// Multipart uploads
std::vector<MultipartUpload> multipartUploads;

public:
/// Creates an OriginalCallbackMessage for the given callback; the callback and all
/// remaining arguments are perfectly forwarded to the message constructor
template <typename Callback, typename... Arguments>
std::unique_ptr<OriginalCallbackMessage<Callback>> makeCallbackMessage(Callback&& c, Arguments&&... args) {
    using message_type = network::OriginalCallbackMessage<Callback>;
    auto message = std::make_unique<message_type>(std::forward<Callback>(c), std::forward<Arguments>(args)...);
    return message;
}

public:
/// The constructor
Transaction() = default;
Transaction() : messages(), messageCounter() {}
/// The explicit constructor with the provider
explicit Transaction(const cloud::Provider* provider) : provider(provider), messages() {}
explicit Transaction(const cloud::Provider* provider) : provider(provider), messages(), messageCounter() {}

/// Set the provider
constexpr void setProvider(const cloud::Provider* provider) { this->provider = provider; }
Expand All @@ -67,13 +96,15 @@ class Transaction {
template <typename Callback>
inline void getObjectRequest(Callback&& callback, const std::string& remotePath, std::pair<uint64_t, uint64_t> range = {0, 0}, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(provider);
auto originalMsg = std::make_unique<network::OriginalCallbackMessage<Callback>>(std::forward<Callback&&>(callback), provider->getRequest(remotePath, range), provider->getAddress(), provider->getPort(), result, capacity, traceId);
auto originalMsg = std::make_unique<network::OriginalCallbackMessage<Callback>>(std::forward<Callback>(callback), provider->getRequest(remotePath, range), provider->getAddress(), provider->getPort(), result, capacity, traceId);
messages.push_back(std::move(originalMsg));
}

/// Build a new put request for synchronous calls
inline void putObjectRequest(const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(provider);
if (provider->multipartUploadSize() && size > provider->multipartUploadSize())
return putObjectRequestMultiPart(remotePath, data, size, result, capacity, traceId);
auto object = std::string_view(data, size);
auto originalMsg = std::make_unique<network::OriginalMessage>(provider->putRequest(remotePath, object), provider->getAddress(), provider->getPort(), result, capacity, traceId);
originalMsg->setPutRequestData(reinterpret_cast<const uint8_t*>(data), size);
Expand All @@ -84,8 +115,10 @@ class Transaction {
template <typename Callback>
inline void putObjectRequest(Callback&& callback, const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(provider);
if (provider->multipartUploadSize() && size > provider->multipartUploadSize())
return putObjectRequestMultiPart(std::forward<Callback>(callback), remotePath, data, size, result, capacity, traceId);
auto object = std::string_view(data, size);
auto originalMsg = std::make_unique<network::OriginalCallbackMessage<Callback>>(std::forward<Callback&&>(callback), provider->putRequest(remotePath, object), provider->getAddress(), provider->getPort(), result, capacity, traceId);
auto originalMsg = std::make_unique<network::OriginalCallbackMessage<Callback>>(std::forward<Callback>(callback), provider->putRequest(remotePath, object), provider->getAddress(), provider->getPort(), result, capacity, traceId);
originalMsg->setPutRequestData(reinterpret_cast<const uint8_t*>(data), size);
messages.push_back(std::move(originalMsg));
}
Expand All @@ -101,10 +134,91 @@ class Transaction {
template <typename Callback>
inline void deleteObjectRequest(Callback&& callback, const std::string& remotePath, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(provider);
auto originalMsg = std::make_unique<network::OriginalCallbackMessage<Callback>>(std::forward<Callback&&>(callback), provider->deleteRequest(remotePath), provider->getAddress(), provider->getPort(), result, capacity, traceId);
auto originalMsg = std::make_unique<network::OriginalCallbackMessage<Callback>>(std::forward<Callback>(callback), provider->deleteRequest(remotePath), provider->getAddress(), provider->getPort(), result, capacity, traceId);
messages.push_back(std::move(originalMsg));
}

private:
/// Build a new put request for synchronous calls
inline void putObjectRequestMultiPart(const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(provider);
auto splitSize = provider->multipartUploadSize();
auto parts = (size / splitSize) + ((size % splitSize) ? 1u : 0u);
multipartUploads.emplace_back(parts);
auto position = multipartUploads.size() - 1;

auto uploadMessages = [position, parts, data, remotePath, traceId, splitSize, size, this](network::MessageResult& result) {
if (!result.success())
return;

multipartUploads[position].uploadId = provider->getUploadId(result.getResult());
auto offset = 0ull;
for (auto i = 1ull; i <= parts; i++) {
auto finishMultipart = [position, remotePath, traceId, i, this](network::MessageResult& result) {
// TODO: requires abort handling
if (!result.success())
return;

multipartUploads[position].eTags[i - 1] = provider->getETag(std::string_view(reinterpret_cast<const char*>(result.getData()), result.getOffset()));
if (multipartUploads[position].outstanding.fetch_sub(1) == 1) {
auto originalMsg = std::make_unique<network::OriginalMessage>(provider->completeMultiPartRequest(remotePath, multipartUploads[position].uploadId, multipartUploads[position].eTags), provider->getAddress(), provider->getPort(), nullptr, 0, traceId);
messages.push_back(std::move(originalMsg));
}
};
auto partSize = (i != parts) ? splitSize : size - offset;
auto object = std::string_view(data + offset, partSize);
auto originalMsg = makeCallbackMessage(std::move(finishMultipart), provider->putRequestGeneric(remotePath, object, i, multipartUploads[position].uploadId), provider->getAddress(), provider->getPort(), nullptr, 0, traceId);
originalMsg->setPutRequestData(reinterpret_cast<const uint8_t*>(data + offset), partSize);
messages.push_back(std::move(originalMsg));
offset += partSize;
}
};

auto originalMsg = makeCallbackMessage(std::move(uploadMessages), provider->createMultiPartRequest(remotePath), provider->getAddress(), provider->getPort(), result, capacity, traceId);
messages.push_back(std::move(originalMsg));
}

/// Build a new put request with callback
template <typename Callback>
inline void putObjectRequestMultiPart(Callback&& callback, const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(provider);
auto splitSize = provider->multipartUploadSize();
auto parts = (size / splitSize) + ((size % splitSize) ? 1u : 0u);
multipartUploads.emplace_back(parts);
auto position = multipartUploads.size() - 1;

auto uploadMessages = [&callback, position, parts, data, remotePath, traceId, splitSize, size, this](network::MessageResult& result) {
if (!result.success())
return;

multipartUploads[position].uploadId = provider->getUploadId(result.getResult());
auto offset = 0ull;
for (auto i = 1ull; i <= parts; i++) {
auto finishMultipart = [&callback, position, remotePath, traceId, i, this](network::MessageResult& result) {
// TODO: requires abort handling
if (!result.success())
return;

multipartUploads[position].eTags[i - 1] = provider->getETag(std::string_view(reinterpret_cast<const char*>(result.getData()), result.getOffset()));
if (multipartUploads[position].outstanding.fetch_sub(1) == 1) {
auto originalMsg = std::make_unique<network::OriginalCallbackMessage<Callback>>(std::forward<Callback>(callback), provider->completeMultiPartRequest(remotePath, multipartUploads[position].uploadId, multipartUploads[position].eTags), provider->getAddress(), provider->getPort(), nullptr, 0, traceId);
messages.push_back(std::move(originalMsg));
}
};
auto partSize = (i != parts) ? splitSize : size - offset;
auto object = std::string_view(data + offset, partSize);
auto originalMsg = makeCallbackMessage(std::move(finishMultipart), provider->putRequestGeneric(remotePath, object, i, multipartUploads[position].uploadId), provider->getAddress(), provider->getPort(), nullptr, 0, traceId);
originalMsg->setPutRequestData(reinterpret_cast<const uint8_t*>(data + offset), partSize);
messages.push_back(std::move(originalMsg));
offset += partSize;
}
};

auto originalMsg = makeCallbackMessage(std::move(uploadMessages), provider->createMultiPartRequest(remotePath), provider->getAddress(), provider->getPort(), result, capacity, traceId);
messages.push_back(std::move(originalMsg));
}

public:
/// The iterator
using iterator = Iterator;
/// The const iterator
Expand Down
9 changes: 4 additions & 5 deletions src/cloud/aws.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ unique_ptr<utils::DataVector<uint8_t>> AWS::putRequestGeneric(const string& file

// Is it a multipart upload?
if (part) {
request.path += "?partNumber=" + to_string(part) + "&uploadId=";
request.path += uploadId;
request.queries.emplace("partNumber", to_string(part));
request.queries.emplace("uploadId", uploadId);
}

request.bodyLength = object.size();
Expand Down Expand Up @@ -320,7 +320,7 @@ unique_ptr<utils::DataVector<uint8_t>> AWS::createMultiPartRequest(const string&
request.path = "/" + filePath;
else
request.path = "/" + _settings.bucket + "/" + filePath;
request.path += "?uploads";
request.queries.emplace("uploads", "");
request.bodyData = nullptr;
request.bodyLength = 0;
request.headers.emplace("Host", getAddress());
Expand Down Expand Up @@ -365,9 +365,8 @@ unique_ptr<utils::DataVector<uint8_t>> AWS::completeMultiPartRequest(const strin
request.path = "/" + filePath;
else
request.path = "/" + _settings.bucket + "/" + filePath;
request.path += "&uploadId=";
request.path += uploadId;

request.queries.emplace("uploadId", uploadId);
request.bodyData = nullptr;
request.bodyLength = 0;
request.headers.emplace("Host", getAddress());
Expand Down
9 changes: 4 additions & 5 deletions src/cloud/gcp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ unique_ptr<utils::DataVector<uint8_t>> GCP::putRequestGeneric(const string& file

// Is it a multipart upload?
if (part) {
request.path += "?partNumber=" + to_string(part) + "&uploadId=";
request.path += uploadId;
request.queries.emplace("partNumber", to_string(part));
request.queries.emplace("uploadId", uploadId);
}

request.bodyData = reinterpret_cast<const uint8_t*>(object.data());
Expand Down Expand Up @@ -171,7 +171,7 @@ unique_ptr<utils::DataVector<uint8_t>> GCP::createMultiPartRequest(const string&
request.method = "POST";
request.type = "HTTP/1.1";
request.path = "/" + filePath;
request.path += "?uploads";
request.queries.emplace("uploads", "");
request.bodyData = nullptr;
request.bodyLength = 0;

Expand Down Expand Up @@ -208,8 +208,7 @@ unique_ptr<utils::DataVector<uint8_t>> GCP::completeMultiPartRequest(const strin
request.method = "POST";
request.type = "HTTP/1.1";
request.path = "/" + filePath;
request.path += "&uploadId=";
request.path += uploadId;
request.queries.emplace("uploadId", uploadId);
request.bodyData = nullptr;
request.bodyLength = 0;

Expand Down
18 changes: 10 additions & 8 deletions src/network/transaction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@ using namespace std;
void Transaction::processSync(TaskedSendReceiver& sendReceiver)
// Processes the request messages
{
// send the original request message
for (auto& msg : messages) {
sendReceiver.sendSync(msg.get());
}
do {
// send the original request message
for (; messageCounter < messages.size(); messageCounter++) {
sendReceiver.sendSync(messages[messageCounter].get());
}

// do the download work
sendReceiver.processSync();
// do the download work
sendReceiver.processSync();
} while (messages.size() != messageCounter);
}
//---------------------------------------------------------------------------
void Transaction::processAsync(TaskedSendReceiver& sendReceiver)
// Sends the request messages to the task group
{
// send the original request message
for (auto& msg : messages) {
while (!sendReceiver.send(msg.get())) {}
for (; messageCounter < messages.size(); messageCounter++) {
while (!sendReceiver.send(messages[messageCounter].get())) {}
}
}
//---------------------------------------------------------------------------
Expand Down
Loading

0 comments on commit b705e36

Please sign in to comment.