Skip to content

Commit

Permalink
Added verifyKeyRequest method
Browse files Browse the repository at this point in the history
  • Loading branch information
durner committed Dec 16, 2023
1 parent 274c341 commit 6873784
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 59 deletions.
17 changes: 11 additions & 6 deletions include/cloud/aws.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,18 @@ class AWS : public Provider {
/// The settings
Settings _settings;
/// The global secret
std::atomic<std::shared_ptr<Secret>> _secret;
/// The session secret
std::atomic<std::shared_ptr<Secret>> _sessionSecret;
std::shared_ptr<Secret> _globalSecret;
/// The global session secret
std::shared_ptr<Secret> _globalSessionSecret;
/// The multipart upload size
uint64_t _multipartUploadSize = 128ull << 20;
/// The secret mutex
std::mutex _mutex;
/// The thread local secret
thread_local static std::shared_ptr<Secret> _secret;
/// The session secret
thread_local static std::shared_ptr<Secret> _sessionSecret;


public:
/// Get instance details
Expand All @@ -96,10 +101,10 @@ class AWS : public Provider {
}
/// The custom endpoint constructor
AWS(const RemoteInfo& info, const std::string& keyId, const std::string& key) : AWS(info) {
_secret = std::make_unique<Secret>();
_globalSecret = std::make_unique<Secret>();
// At init it is fine to simply overwrite
_secret.load()->keyId = keyId;
_secret.load()->secret = key;
_globalSecret->keyId = keyId;
_globalSecret->secret = key;
}

private:
Expand Down
55 changes: 43 additions & 12 deletions include/network/transaction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Transaction {
};

/// The provider
const cloud::Provider* _provider;
cloud::Provider* _provider;

/// The message
message_vector_type _messages;
Expand All @@ -91,80 +91,108 @@ class Transaction {
/// The constructor
Transaction() : _provider(), _messages(), _messageCounter(), _multipartUploads(), _completedMultiparts() {}
/// The explicit constructor with the provider
explicit Transaction(const cloud::Provider* provider) : _provider(provider), _messages(), _messageCounter(), _multipartUploads(), _completedMultiparts() {}
explicit Transaction(cloud::Provider* provider) : _provider(provider), _messages(), _messageCounter(), _multipartUploads(), _completedMultiparts() {}

/// Set the provider
constexpr void setProvider(const cloud::Provider* provider) { this->_provider = provider; }
constexpr void setProvider(cloud::Provider* provider) { this->_provider = provider; }
/// Sends the request messages to the task group
void processAsync(network::TaskedSendReceiverGroup& group);
/// Processes the request messages
void processSync(TaskedSendReceiver& sendReceiver);

/// Function to ensure fresh keys before creating messages
/// This is needed to ensure valid keys before a message is requested
/// Simply forward a task send receiver the message function and the args of this message
template <typename Function>
bool verifyKeyRequest(TaskedSendReceiver& sendReceiver, Function&& func) {
assert(_provider);
_provider->initSecret(sendReceiver);
return std::forward<Function>(func)();
}

/// Build a new get request for synchronous calls
/// Note that the range is [start, end[, [0, 0[ gets the whole object
inline void getObjectRequest(const std::string& remotePath, std::pair<uint64_t, uint64_t> range = {0, 0}, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
inline bool getObjectRequest(const std::string& remotePath, std::pair<uint64_t, uint64_t> range = {0, 0}, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(_provider);
auto originalMsg = std::make_unique<network::OriginalMessage>(_provider->getRequest(remotePath, range), _provider->getAddress(), _provider->getPort(), result, capacity, traceId);
if (!originalMsg)
return false;
_messages.push_back(std::move(originalMsg));
return true;
}

/// Build a new get request with callback
/// Note that the range is [start, end[, [0, 0[ gets the whole object
template <typename Callback>
inline void getObjectRequest(Callback&& callback, const std::string& remotePath, std::pair<uint64_t, uint64_t> range = {0, 0}, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
inline bool getObjectRequest(Callback&& callback, const std::string& remotePath, std::pair<uint64_t, uint64_t> range = {0, 0}, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(_provider);
auto originalMsg = std::make_unique<network::OriginalCallbackMessage<Callback>>(std::forward<Callback>(callback), _provider->getRequest(remotePath, range), _provider->getAddress(), _provider->getPort(), result, capacity, traceId);
if (!originalMsg)
return false;
_messages.push_back(std::move(originalMsg));
return true;
}

/// Build a new put request for synchronous calls
inline void putObjectRequest(const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
inline bool putObjectRequest(const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(_provider);
if (_provider->multipartUploadSize() && size > _provider->multipartUploadSize())
return putObjectRequestMultiPart(remotePath, data, size, result, capacity, traceId);
auto object = std::string_view(data, size);
auto originalMsg = std::make_unique<network::OriginalMessage>(_provider->putRequest(remotePath, object), _provider->getAddress(), _provider->getPort(), result, capacity, traceId);
originalMsg->setPutRequestData(reinterpret_cast<const uint8_t*>(data), size);
if (!originalMsg)
return false;
_messages.push_back(std::move(originalMsg));
return true;
}

/// Build a new put request with callback
template <typename Callback>
inline void putObjectRequest(Callback&& callback, const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
inline bool putObjectRequest(Callback&& callback, const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(_provider);
if (_provider->multipartUploadSize() && size > _provider->multipartUploadSize())
return putObjectRequestMultiPart(std::forward<Callback>(callback), remotePath, data, size, result, capacity, traceId);
auto object = std::string_view(data, size);
auto originalMsg = std::make_unique<network::OriginalCallbackMessage<Callback>>(std::forward<Callback>(callback), _provider->putRequest(remotePath, object), _provider->getAddress(), _provider->getPort(), result, capacity, traceId);
originalMsg->setPutRequestData(reinterpret_cast<const uint8_t*>(data), size);
if (!originalMsg)
return false;
_messages.push_back(std::move(originalMsg));
return true;
}

/// Build a new delete request for synchronous calls
inline void deleteObjectRequest(const std::string& remotePath, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
inline bool deleteObjectRequest(const std::string& remotePath, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(_provider);
auto originalMsg = std::make_unique<network::OriginalMessage>(_provider->deleteRequest(remotePath), _provider->getAddress(), _provider->getPort(), result, capacity, traceId);
if (!originalMsg)
return false;
_messages.push_back(std::move(originalMsg));
return true;
}

/// Build a new delete request with callback
template <typename Callback>
inline void deleteObjectRequest(Callback&& callback, const std::string& remotePath, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
inline bool deleteObjectRequest(Callback&& callback, const std::string& remotePath, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(_provider);
auto originalMsg = std::make_unique<network::OriginalCallbackMessage<Callback>>(std::forward<Callback>(callback), _provider->deleteRequest(remotePath), _provider->getAddress(), _provider->getPort(), result, capacity, traceId);
if (!originalMsg)
return false;
_messages.push_back(std::move(originalMsg));
return true;
}

private:
/// Build a new put request for synchronous calls
inline void putObjectRequestMultiPart(const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
inline bool putObjectRequestMultiPart(const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
auto finished = [](network::MessageResult& /*result*/) {};
putObjectRequestMultiPart(std::move(finished), remotePath, data, size, result, capacity, traceId);
return putObjectRequestMultiPart(std::move(finished), remotePath, data, size, result, capacity, traceId);
}

/// Build a new put request with callback
template <typename Callback>
inline void putObjectRequestMultiPart(Callback&& callback, const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
inline bool putObjectRequestMultiPart(Callback&& callback, const std::string& remotePath, const char* data, uint64_t size, uint8_t* result = nullptr, uint64_t capacity = 0, uint64_t traceId = 0) {
assert(_provider);
auto splitSize = _provider->multipartUploadSize();
auto parts = (size / splitSize) + ((size % splitSize) ? 1u : 0u);
Expand Down Expand Up @@ -222,7 +250,10 @@ class Transaction {
};

auto originalMsg = makeCallbackMessage(std::move(uploadMessages), _provider->createMultiPartRequest(remotePath), _provider->getAddress(), _provider->getPort(), result, capacity, traceId);
if (!originalMsg)
return false;
_messages.push_back(std::move(originalMsg));
return true;
}

public:
Expand Down
69 changes: 41 additions & 28 deletions src/cloud/aws.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ namespace cloud {
//---------------------------------------------------------------------------
using namespace std;
//---------------------------------------------------------------------------
thread_local shared_ptr<AWS::Secret> AWS::_secret = nullptr;
thread_local shared_ptr<AWS::Secret> AWS::_sessionSecret = nullptr;
//---------------------------------------------------------------------------
static string buildAMZTimestamp()
// Creates the AWS timestamp
{
Expand Down Expand Up @@ -148,7 +151,8 @@ bool AWS::updateSecret(string_view content, string_view iamUser)
string timestamp(sv.begin(), sv.end());
secret->expiration = convertIAMTimestamp(timestamp);
secret->iamUser = iamUser;
_secret.exchange(secret);
_globalSecret = secret;
_secret = secret;
return true;
}
//---------------------------------------------------------------------------
Expand Down Expand Up @@ -194,24 +198,23 @@ bool AWS::updateSessionToken(string_view content)
auto sv = content.substr(pos, end - pos);
string timestamp(sv.begin(), sv.end());
secret->expiration = convertIAMTimestamp(timestamp);
_sessionSecret.exchange(secret);
_globalSessionSecret = secret;
_sessionSecret = secret;
return true;
}
//---------------------------------------------------------------------------
bool AWS::validKeys(uint32_t offset) const
// Checks whether keys need to be refresehd
{
auto secret = _secret.load();
if (!secret || ((!secret->token.empty() && secret->expiration - offset < chrono::system_clock::to_time_t(chrono::system_clock::now())) || secret->secret.empty()))
if (!_secret || ((!_secret->token.empty() && _secret->expiration - offset < chrono::system_clock::to_time_t(chrono::system_clock::now())) || _secret->secret.empty()))
return false;
return true;
}
//---------------------------------------------------------------------------
bool AWS::validSession(uint32_t offset) const
// Checks whether the session token needs to be refresehd
{
auto secret = _sessionSecret.load();
if (!secret || ((!secret->token.empty() && secret->expiration - offset < chrono::system_clock::to_time_t(chrono::system_clock::now())) || secret->secret.empty()))
if (!_sessionSecret || ((!_sessionSecret->token.empty() && _sessionSecret->expiration - offset < chrono::system_clock::to_time_t(chrono::system_clock::now())) || _sessionSecret->secret.empty()))
return false;
return true;
}
Expand All @@ -220,27 +223,38 @@ void AWS::initSecret(network::TaskedSendReceiver& sendReceiver)
// Uses the send receiver to initialize the secret
{
if (_type == Provider::CloudService::AWS && !validKeys(180)) {
auto secret = make_shared<Secret>();
auto message = downloadIAMUser();
auto originalMsg = make_unique<network::OriginalMessage>(move(message), getIAMAddress(), getIAMPort());
sendReceiver.sendSync(originalMsg.get());
sendReceiver.processSync();
auto& content = originalMsg->result.getDataVector();
unique_ptr<network::HTTPHelper::Info> infoPtr;
auto s = network::HTTPHelper::retrieveContent(content.cdata(), content.size(), infoPtr);
string iamUser;
message = downloadSecret(s, iamUser);
originalMsg = make_unique<network::OriginalMessage>(move(message), getIAMAddress(), getIAMPort());
sendReceiver.sendSync(originalMsg.get());
sendReceiver.processSync();
auto& secretContent = originalMsg->result.getDataVector();
infoPtr.reset();
s = network::HTTPHelper::retrieveContent(secretContent.cdata(), secretContent.size(), infoPtr);
updateSecret(s, iamUser);
while (true) {
if (_mutex.try_lock()) {
_secret = _globalSecret;
if (validKeys(180))
return;
auto secret = make_shared<Secret>();
auto message = downloadIAMUser();
auto originalMsg = make_unique<network::OriginalMessage>(move(message), getIAMAddress(), getIAMPort());
sendReceiver.sendSync(originalMsg.get());
sendReceiver.processSync();
auto& content = originalMsg->result.getDataVector();
unique_ptr<network::HTTPHelper::Info> infoPtr;
auto s = network::HTTPHelper::retrieveContent(content.cdata(), content.size(), infoPtr);
string iamUser;
message = downloadSecret(s, iamUser);
originalMsg = make_unique<network::OriginalMessage>(move(message), getIAMAddress(), getIAMPort());
sendReceiver.sendSync(originalMsg.get());
sendReceiver.processSync();
auto& secretContent = originalMsg->result.getDataVector();
infoPtr.reset();
s = network::HTTPHelper::retrieveContent(secretContent.cdata(), secretContent.size(), infoPtr);
updateSecret(s, iamUser);
_mutex.unlock();
}
if (validKeys(60))
return;
}
}
if (_type == Provider::CloudService::AWS && _settings.zonal && !validSession(180)) {
while (true) {
if (_mutex.try_lock()) {
_sessionSecret = _globalSessionSecret;
if (validSession(180))
return;

Expand Down Expand Up @@ -277,11 +291,11 @@ unique_ptr<utils::DataVector<uint8_t>> AWS::buildRequest(AWSSigner::Request& req
request.headers.emplace("x-amz-date", testEnviornment ? fakeAMZTimestamp : buildAMZTimestamp());
if (!_settings.zonal) {
request.headers.emplace("x-amz-request-payer", "requester");
secret = _secret.load();
secret = _secret;
if (!secret->token.empty())
request.headers.emplace("x-amz-security-token", secret->token);
} else {
secret = _sessionSecret.load();
secret = _sessionSecret;
request.headers.emplace("x-amz-s3session-token", secret->token);
}
}
Expand Down Expand Up @@ -453,9 +467,8 @@ unique_ptr<utils::DataVector<uint8_t>> AWS::getSessionToken(string_view type) co
request.headers.emplace("Host", _settings.bucket + ".s3.amazonaws.com");
request.headers.emplace("x-amz-create-session-mode", type);
request.headers.emplace("x-amz-date", testEnviornment ? fakeAMZTimestamp : buildAMZTimestamp());
auto secret = _secret.load();
if (!secret->token.empty())
request.headers.emplace("x-amz-security-token", secret->token);
if (!_secret->token.empty())
request.headers.emplace("x-amz-security-token", _secret->token);

return buildRequest(request, false);
}
Expand Down
1 change: 1 addition & 0 deletions src/cloud/azure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "network/tasked_send_receiver.hpp"
#include "network/resolver.hpp"
#include "utils/data_vector.hpp"
#include <algorithm>
#include <chrono>
#include <iomanip>
#include <sstream>
Expand Down
Loading

0 comments on commit 6873784

Please sign in to comment.