From 6cdd4b1f1d350a03281c45aec3086b66d88e715b Mon Sep 17 00:00:00 2001 From: Dominik Durner Date: Fri, 15 Dec 2023 13:51:36 +0100 Subject: [PATCH] Added verifyKeyRequest method --- include/cloud/aws.hpp | 19 ++++--- include/cloud/provider.hpp | 2 + include/network/transaction.hpp | 62 ++++++++++++++++++----- src/cloud/aws.cpp | 86 +++++++++++++++++++++----------- src/cloud/azure.cpp | 1 + test/integration/minio_async.cpp | 47 ++++++++++++----- 6 files changed, 157 insertions(+), 60 deletions(-) diff --git a/include/cloud/aws.hpp b/include/cloud/aws.hpp index 360ef98..de3809a 100644 --- a/include/cloud/aws.hpp +++ b/include/cloud/aws.hpp @@ -67,13 +67,18 @@ class AWS : public Provider { /// The settings Settings _settings; /// The global secret - std::atomic> _secret; - /// The session secret - std::atomic> _sessionSecret; + std::shared_ptr _globalSecret; + /// The global session secret + std::shared_ptr _globalSessionSecret; /// The multipart upload size uint64_t _multipartUploadSize = 128ull << 20; /// The secret mutex std::mutex _mutex; + /// The thread local secret + thread_local static std::shared_ptr _secret; + /// The session secret + thread_local static std::shared_ptr _sessionSecret; + public: /// Get instance details @@ -96,15 +101,17 @@ class AWS : public Provider { } /// The custom endpoint constructor AWS(const RemoteInfo& info, const std::string& keyId, const std::string& key) : AWS(info) { - _secret = std::make_unique(); + _globalSecret = std::make_unique(); // At init it is fine to simply overwrite - _secret.load()->keyId = keyId; - _secret.load()->secret = key; + _globalSecret->keyId = keyId; + _globalSecret->secret = key; } private: /// Initialize secret void initSecret(network::TaskedSendReceiver& sendReceiver) override; + /// Get a local copy of the global secret + void getSecret() override; /// Builds the secret http request [[nodiscard]] std::unique_ptr> downloadIAMUser() const; /// Builds the secret http request diff --git a/include/cloud/provider.hpp b/include/cloud/provider.hpp index 02329d9..b0bc6e7 100644 --- a/include/cloud/provider.hpp +++ b/include/cloud/provider.hpp @@ -101,6 +101,8 @@ class Provider { /// Initialize secret virtual void initSecret(network::TaskedSendReceiver& /*sendReceiver*/) {} + /// Get a local copy of the global secret + virtual void getSecret() {} public: /// The destructor diff --git a/include/network/transaction.hpp b/include/network/transaction.hpp index c3cb09f..5802e0b 100644 --- a/include/network/transaction.hpp +++ b/include/network/transaction.hpp @@ -70,7 +70,7 @@ class Transaction { }; /// The provider - const cloud::Provider* _provider; + cloud::Provider* _provider; /// The message message_vector_type _messages; @@ -91,81 +91,116 @@ class Transaction { /// The constructor Transaction() : _provider(), _messages(), _messageCounter(), _multipartUploads(), _completedMultiparts() {} /// The explicit constructor with the provider - explicit Transaction(const cloud::Provider* provider) : _provider(provider), _messages(), _messageCounter(), _multipartUploads(), _completedMultiparts() {} + explicit Transaction(cloud::Provider* provider) : _provider(provider), _messages(), _messageCounter(), _multipartUploads(), _completedMultiparts() {} /// Set the provider - constexpr void setProvider(const cloud::Provider* provider) { this->_provider = provider; } + constexpr void setProvider(cloud::Provider* provider) { this->_provider = provider; } /// Sends the request messages to the task group void processAsync(network::TaskedSendReceiverGroup& group); /// Processes the request messages void processSync(TaskedSendReceiver& sendReceiver); + /// Function to ensure fresh keys before creating messages + /// This is needed to ensure valid keys before a message is requested + /// Simply forward a task send receiver the message function and the args of this message + template + bool verifyKeyRequest(TaskedSendReceiver& sendReceiver, Function&& func) { + assert(_provider); + _provider->initSecret(sendReceiver); + return std::forward(func)(); + } + /// Build a new get request for synchronous calls /// Note that the range is [start, end[, [0, 0[ gets the whole object - inline void getObjectRequest(const std::string& remotePath, std::pair range = {0, 0}, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { + inline bool getObjectRequest(const std::string& remotePath, std::pair range = {0, 0}, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { assert(_provider); + _provider->getSecret(); auto originalMsg = std::make_unique(_provider->getRequest(remotePath, range), _provider->getAddress(), _provider->getPort(), result, capacity, traceId); + if (!originalMsg) + return false; _messages.push_back(std::move(originalMsg)); + return true; } /// Build a new get request with callback /// Note that the range is [start, end[, [0, 0[ gets the whole object template - inline void getObjectRequest(Callback&& callback, const std::string& remotePath, std::pair range = {0, 0}, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { + inline bool getObjectRequest(Callback&& callback, const std::string& remotePath, std::pair range = {0, 0}, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { assert(_provider); + _provider->getSecret(); auto originalMsg = std::make_unique>(std::forward(callback), _provider->getRequest(remotePath, range), _provider->getAddress(), _provider->getPort(), result, capacity, traceId); + if (!originalMsg) + return false; _messages.push_back(std::move(originalMsg)); + return true; } /// Build a new put request for synchronous calls - inline void putObjectRequest(const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { + inline bool putObjectRequest(const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { assert(_provider); + _provider->getSecret(); if (_provider->multipartUploadSize() && size > _provider->multipartUploadSize()) return putObjectRequestMultiPart(remotePath, data, size, result, capacity, traceId); auto object = std::string_view(data, size); auto originalMsg = std::make_unique(_provider->putRequest(remotePath, object), _provider->getAddress(), _provider->getPort(), result, capacity, traceId); originalMsg->setPutRequestData(reinterpret_cast(data), size); + if (!originalMsg) + return false; _messages.push_back(std::move(originalMsg)); + return true; } /// Build a new put request with callback template - inline void putObjectRequest(Callback&& callback, const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { + inline bool putObjectRequest(Callback&& callback, const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { assert(_provider); + _provider->getSecret(); if (_provider->multipartUploadSize() && size > _provider->multipartUploadSize()) return putObjectRequestMultiPart(std::forward(callback), remotePath, data, size, result, capacity, traceId); auto object = std::string_view(data, size); auto originalMsg = std::make_unique>(std::forward(callback), _provider->putRequest(remotePath, object), _provider->getAddress(), _provider->getPort(), result, capacity, traceId); originalMsg->setPutRequestData(reinterpret_cast(data), size); + if (!originalMsg) + return false; _messages.push_back(std::move(originalMsg)); + return true; } /// Build a new delete request for synchronous calls - inline void deleteObjectRequest(const std::string& remotePath, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { + inline bool deleteObjectRequest(const std::string& remotePath, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { assert(_provider); + _provider->getSecret(); auto originalMsg = std::make_unique(_provider->deleteRequest(remotePath), _provider->getAddress(), _provider->getPort(), result, capacity, traceId); + if (!originalMsg) + return false; _messages.push_back(std::move(originalMsg)); + return true; } /// Build a new delete request with callback template - inline void deleteObjectRequest(Callback&& callback, const std::string& remotePath, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { + inline bool deleteObjectRequest(Callback&& callback, const std::string& remotePath, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { assert(_provider); + _provider->getSecret(); auto originalMsg = std::make_unique>(std::forward(callback), _provider->deleteRequest(remotePath), _provider->getAddress(), _provider->getPort(), result, capacity, traceId); + if (!originalMsg) + return false; _messages.push_back(std::move(originalMsg)); + return true; } private: /// Build a new put request for synchronous calls - inline void putObjectRequestMultiPart(const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { + inline bool putObjectRequestMultiPart(const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { auto finished = [](network::MessageResult& /*result*/) {}; - putObjectRequestMultiPart(std::move(finished), remotePath, data, size, result, capacity, traceId); + return putObjectRequestMultiPart(std::move(finished), remotePath, data, size, result, capacity, traceId); } /// Build a new put request with callback template - inline void putObjectRequestMultiPart(Callback&& callback, const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { + inline bool putObjectRequestMultiPart(Callback&& callback, const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) { assert(_provider); + _provider->getSecret(); auto splitSize = _provider->multipartUploadSize(); auto parts = (size / splitSize) + ((size % splitSize) ? 1u : 0u); _multipartUploads.emplace_back(parts); @@ -222,7 +257,10 @@ class Transaction { }; auto originalMsg = makeCallbackMessage(std::move(uploadMessages), _provider->createMultiPartRequest(remotePath), _provider->getAddress(), _provider->getPort(), result, capacity, traceId); + if (!originalMsg) + return false; _messages.push_back(std::move(originalMsg)); + return true; } public: diff --git a/src/cloud/aws.cpp b/src/cloud/aws.cpp index a6e27c8..0b194a0 100644 --- a/src/cloud/aws.cpp +++ b/src/cloud/aws.cpp @@ -6,8 +6,8 @@ #include "utils/data_vector.hpp" #include #include -#include #include +#include //--------------------------------------------------------------------------- // AnyBlob - Universal Cloud Object Storage Library // Dominik Durner, 2021 @@ -21,6 +21,9 @@ namespace cloud { //--------------------------------------------------------------------------- using namespace std; //--------------------------------------------------------------------------- +thread_local shared_ptr AWS::_secret = nullptr; +thread_local shared_ptr AWS::_sessionSecret = nullptr; +//--------------------------------------------------------------------------- static string buildAMZTimestamp() // Creates the AWS timestamp { @@ -148,7 +151,8 @@ bool AWS::updateSecret(string_view content, string_view iamUser) string timestamp(sv.begin(), sv.end()); secret->expiration = convertIAMTimestamp(timestamp); secret->iamUser = iamUser; - _secret.exchange(secret); + _globalSecret = secret; + _secret = secret; return true; } //--------------------------------------------------------------------------- @@ -194,15 +198,15 @@ bool AWS::updateSessionToken(string_view content) auto sv = content.substr(pos, end - pos); string timestamp(sv.begin(), sv.end()); secret->expiration = convertIAMTimestamp(timestamp); - _sessionSecret.exchange(secret); + _globalSessionSecret = secret; + _sessionSecret = secret; return true; } //--------------------------------------------------------------------------- bool AWS::validKeys(uint32_t offset) const // Checks whether keys need to be refresehd { - auto secret = _secret.load(); - if (!secret || ((!secret->token.empty() && secret->expiration - offset < chrono::system_clock::to_time_t(chrono::system_clock::now())) || secret->secret.empty())) + if (!_secret || ((!_secret->token.empty() && _secret->expiration - offset < chrono::system_clock::to_time_t(chrono::system_clock::now())) || _secret->secret.empty())) return false; return true; } @@ -210,8 +214,7 @@ bool AWS::validKeys(uint32_t offset) const bool AWS::validSession(uint32_t offset) const // Checks whether the session token needs to be refresehd { - auto secret = _sessionSecret.load(); - if (!secret || ((!secret->token.empty() && secret->expiration - offset < chrono::system_clock::to_time_t(chrono::system_clock::now())) || secret->secret.empty())) + if (!_sessionSecret || ((!_sessionSecret->token.empty() && _sessionSecret->expiration - offset < chrono::system_clock::to_time_t(chrono::system_clock::now())) || _sessionSecret->secret.empty())) return false; return true; } @@ -220,27 +223,38 @@ void AWS::initSecret(network::TaskedSendReceiver& sendReceiver) // Uses the send receiver to initialize the secret { if (_type == Provider::CloudService::AWS && !validKeys(180)) { - auto secret = make_shared(); - auto message = downloadIAMUser(); - auto originalMsg = make_unique(move(message), getIAMAddress(), getIAMPort()); - sendReceiver.sendSync(originalMsg.get()); - sendReceiver.processSync(); - auto& content = originalMsg->result.getDataVector(); - unique_ptr infoPtr; - auto s = network::HTTPHelper::retrieveContent(content.cdata(), content.size(), infoPtr); - string iamUser; - message = downloadSecret(s, iamUser); - originalMsg = make_unique(move(message), getIAMAddress(), getIAMPort()); - sendReceiver.sendSync(originalMsg.get()); - sendReceiver.processSync(); - auto& secretContent = originalMsg->result.getDataVector(); - infoPtr.reset(); - s = network::HTTPHelper::retrieveContent(secretContent.cdata(), secretContent.size(), infoPtr); - updateSecret(s, iamUser); + while (true) { + if (_mutex.try_lock()) { + _secret = _globalSecret; + if (validKeys(180)) + return; + auto secret = make_shared(); + auto message = downloadIAMUser(); + auto originalMsg = make_unique(move(message), getIAMAddress(), getIAMPort()); + sendReceiver.sendSync(originalMsg.get()); + sendReceiver.processSync(); + auto& content = originalMsg->result.getDataVector(); + unique_ptr infoPtr; + auto s = network::HTTPHelper::retrieveContent(content.cdata(), content.size(), infoPtr); + string iamUser; + message = downloadSecret(s, iamUser); + originalMsg = make_unique(move(message), getIAMAddress(), getIAMPort()); + sendReceiver.sendSync(originalMsg.get()); + sendReceiver.processSync(); + auto& secretContent = originalMsg->result.getDataVector(); + infoPtr.reset(); + s = network::HTTPHelper::retrieveContent(secretContent.cdata(), secretContent.size(), infoPtr); + updateSecret(s, iamUser); + _mutex.unlock(); + } + if (validKeys(60)) + return; + } } if (_type == Provider::CloudService::AWS && _settings.zonal && !validSession(180)) { while (true) { if (_mutex.try_lock()) { + _sessionSecret = _globalSessionSecret; if (validSession(180)) return; @@ -260,6 +274,21 @@ void AWS::initSecret(network::TaskedSendReceiver& sendReceiver) } } //--------------------------------------------------------------------------- +void AWS::getSecret() +// Updates the local secret +{ + if (!_secret) { + _mutex.lock(); + _secret = _globalSecret; + _mutex.unlock(); + } + if (_type == Provider::CloudService::AWS && _settings.zonal && !_sessionSecret) { + _mutex.lock(); + _sessionSecret = _globalSessionSecret; + _mutex.unlock(); + } +} +//--------------------------------------------------------------------------- void AWS::initResolver(network::TaskedSendReceiver& sendReceiver) // Inits the resolver { @@ -277,11 +306,11 @@ unique_ptr> AWS::buildRequest(AWSSigner::Request& req request.headers.emplace("x-amz-date", testEnviornment ? fakeAMZTimestamp : buildAMZTimestamp()); if (!_settings.zonal) { request.headers.emplace("x-amz-request-payer", "requester"); - secret = _secret.load(); + secret = _secret; if (!secret->token.empty()) request.headers.emplace("x-amz-security-token", secret->token); } else { - secret = _sessionSecret.load(); + secret = _sessionSecret; request.headers.emplace("x-amz-s3session-token", secret->token); } } @@ -453,9 +482,8 @@ unique_ptr> AWS::getSessionToken(string_view type) co request.headers.emplace("Host", _settings.bucket + ".s3.amazonaws.com"); request.headers.emplace("x-amz-create-session-mode", type); request.headers.emplace("x-amz-date", testEnviornment ? fakeAMZTimestamp : buildAMZTimestamp()); - auto secret = _secret.load(); - if (!secret->token.empty()) - request.headers.emplace("x-amz-security-token", secret->token); + if (!_secret->token.empty()) + request.headers.emplace("x-amz-security-token", _secret->token); return buildRequest(request, false); } diff --git a/src/cloud/azure.cpp b/src/cloud/azure.cpp index 3eb41cc..d070ed4 100644 --- a/src/cloud/azure.cpp +++ b/src/cloud/azure.cpp @@ -5,6 +5,7 @@ #include "network/tasked_send_receiver.hpp" #include "network/resolver.hpp" #include "utils/data_vector.hpp" +#include #include #include #include diff --git a/test/integration/minio_async.cpp b/test/integration/minio_async.cpp index 38fdfde..2627531 100644 --- a/test/integration/minio_async.cpp +++ b/test/integration/minio_async.cpp @@ -73,7 +73,7 @@ TEST_CASE("MinIO Asynchronous Integration") { { // Check the upload for success - std::atomic finishedMessages = 0; + atomic finishedMessages = 0; auto checkSuccess = [&finishedMessages](anyblob::network::MessageResult& result) { // Sucessful request REQUIRE(result.success()); @@ -82,8 +82,12 @@ TEST_CASE("MinIO Asynchronous Integration") { // Create the put request anyblob::network::Transaction putTxn(provider.get()); - for (auto i = 0u; i < 2; i++) - putTxn.putObjectRequest(checkSuccess, fileName[i], content[i].data(), content[i].size()); + for (auto i = 0u; i < 2; i++) { + auto putObjectRequest = [&putTxn, &fileName, &content, &checkSuccess, i]() { + return putTxn.putObjectRequest(checkSuccess, fileName[i], content[i].data(), content[i].size()); + }; + putTxn.verifyKeyRequest(sendReceiver, move(putObjectRequest)); + } // Upload the request asynchronously putTxn.processAsync(group); @@ -94,7 +98,7 @@ TEST_CASE("MinIO Asynchronous Integration") { } { // Check the upload for success - std::atomic finishedMessages = 0; + atomic finishedMessages = 0; auto checkSuccess = [&finishedMessages](anyblob::network::MessageResult& result) { // Sucessful request REQUIRE(result.success()); @@ -105,8 +109,12 @@ TEST_CASE("MinIO Asynchronous Integration") { auto minio = static_cast(provider.get()); minio->setMultipartUploadSize(6ull << 20); anyblob::network::Transaction putTxn(provider.get()); - for (auto i = 0u; i < 2; i++) - putTxn.putObjectRequest(checkSuccess, fileName[i], content[i].data(), content[i].size()); + for (auto i = 0u; i < 2; i++) { + auto putObjectRequest = [&putTxn, &fileName, &content, &checkSuccess, i]() { + return putTxn.putObjectRequest(checkSuccess, fileName[i], content[i].data(), content[i].size()); + }; + putTxn.verifyKeyRequest(sendReceiver, move(putObjectRequest)); + } while (finishedMessages != 2) { // Upload the new request asynchronously @@ -117,7 +125,7 @@ TEST_CASE("MinIO Asynchronous Integration") { } { // Check the upload for failure due to too small part - std::atomic finishedMessages = 0; + atomic finishedMessages = 0; auto checkSuccess = [&finishedMessages](anyblob::network::MessageResult& result) { // Sucessful request REQUIRE(!result.success()); @@ -128,7 +136,11 @@ TEST_CASE("MinIO Asynchronous Integration") { auto minio = static_cast(provider.get()); minio->setMultipartUploadSize(1ull << 20); // too small, requires at least 5MiB parts anyblob::network::Transaction putTxn(provider.get()); - putTxn.putObjectRequest(checkSuccess, fileName[1], content[1].data(), content[1].size()); + + auto putObjectRequest = [&putTxn, &fileName, &content, callback = move(checkSuccess)]() { + return putTxn.putObjectRequest(move(callback), fileName[1], content[1].data(), content[1].size()); + }; + putTxn.verifyKeyRequest(sendReceiver, move(putObjectRequest)); while (finishedMessages != 1) { // Upload the new request asynchronously @@ -138,7 +150,7 @@ TEST_CASE("MinIO Asynchronous Integration") { } } { - std::atomic finishedMessages = 0; + atomic finishedMessages = 0; // Create the get request anyblob::network::Transaction getTxn(provider.get()); for (auto i = 0u; i < 2; i++) { @@ -163,7 +175,11 @@ TEST_CASE("MinIO Asynchronous Integration") { finishedMessages++; }; - getTxn.getObjectRequest(std::move(checkSuccess), fileName[i]); + auto& currentFileName = fileName[i]; + auto getObjectRequest = [&getTxn, ¤tFileName, callback = move(checkSuccess)]() { + return getTxn.getObjectRequest(move(callback), currentFileName); + }; + getTxn.verifyKeyRequest(sendReceiver, move(getObjectRequest)); } // Retrieve the request asynchronously @@ -175,7 +191,7 @@ TEST_CASE("MinIO Asynchronous Integration") { } { // Check the delete for success - std::atomic finishedMessages = 0; + atomic finishedMessages = 0; auto checkSuccess = [&finishedMessages](anyblob::network::MessageResult& result) { // Sucessful request REQUIRE(result.success()); @@ -184,8 +200,13 @@ TEST_CASE("MinIO Asynchronous Integration") { // Create the delete request anyblob::network::Transaction deleteTxn(provider.get()); - for (auto i = 0u; i < 2; i++) - deleteTxn.deleteObjectRequest(checkSuccess, fileName[i]); + for (auto i = 0u; i < 2; i++) { + auto& currentFileName = fileName[i]; + auto deleteRequest = [&deleteTxn, ¤tFileName, callback = move(checkSuccess)]() { + return deleteTxn.deleteObjectRequest(move(callback), currentFileName); + }; + deleteTxn.verifyKeyRequest(sendReceiver, move(deleteRequest)); + } // Process the request asynchronously deleteTxn.processAsync(group);