diff --git a/lib/http/HttpClient_WinInet.cpp b/lib/http/HttpClient_WinInet.cpp index 0589f9796..54875f7ad 100644 --- a/lib/http/HttpClient_WinInet.cpp +++ b/lib/http/HttpClient_WinInet.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include @@ -22,6 +23,8 @@ namespace MAT_NS_BEGIN { +const std::string kProxyRegKeyPath = "Software\\Microsoft\\Windows\\CurrentVersion\\Internet Settings"; + class WinInetRequestWrapper { protected: @@ -136,6 +139,105 @@ class WinInetRequestWrapper return true; } + DWORD GetDWORDRegKey(HKEY hKey, const std::string& subKey, const std::string& value) { + DWORD data; + DWORD dataSize = sizeof(data); + LONG res = RegGetValueA(hKey, subKey.c_str(), value.c_str(), RRF_RT_REG_DWORD, NULL, &data, + &dataSize); + if (res != ERROR_SUCCESS) { + return 0; + } + return data; + } + + std::string GetStringRegKey(HKEY hKey, const std::string& subKey, const std::string& value) { + DWORD stringSize; + LONG getSizeResult = RegGetValueA(hKey, subKey.c_str(), value.c_str(), RRF_RT_REG_SZ, NULL, + NULL, &stringSize); + if (getSizeResult != ERROR_SUCCESS) { + return ""; + } + + std::string data; + data.resize(stringSize); + LONG getStringResult = RegGetValueA(hKey, subKey.c_str(), value.c_str(), RRF_RT_REG_SZ, NULL, + &data[0], &stringSize); + if (getStringResult != ERROR_SUCCESS) { + return ""; + } + + // Remove the null terminator from Win32 API. The std::string will already cover this. + data.resize(stringSize - 1); + + return data; + } + + bool ProxyAuthRequired() { + return GetDWORDRegKey(HKEY_CURRENT_USER, kProxyRegKeyPath, "ProxyEnable"); + } + + void SetProxyCredentials() { + // get proxy server from registry + std::string proxyServer = GetStringRegKey(HKEY_CURRENT_USER, kProxyRegKeyPath, "ProxyServer"); + if (proxyServer.empty()) { + return; + } + + std::wstring formattedProxyServer; + + // If there is only one colon, then it is most likely of form IP:PORT and we just need + // to remove the port. If there are two colons, then it is most likely of form SCHEME:IP:PORT + int colonCount = 0; + for (const char& c : proxyServer) { + if (c == ':') { + colonCount++; + } + } + + if (colonCount == 1) { + const auto colonPos = proxyServer.find_first_of(':'); + auto portRemoved = proxyServer.substr(0, colonPos); + formattedProxyServer = std::wstring(portRemoved.begin(), portRemoved.end()); + } else if (colonCount == 2) { + const auto firstColonPos = proxyServer.find_first_of(':'); + auto schemeRemoved = proxyServer.substr(firstColonPos + 3); + + const auto lastColonPos = schemeRemoved.find_last_of(':'); + auto portAndSchemeRemoved = schemeRemoved.substr(0, lastColonPos); + + formattedProxyServer = std::wstring(portAndSchemeRemoved.begin(), + portAndSchemeRemoved.end()); + } else { + formattedProxyServer = std::wstring(proxyServer.begin(), + proxyServer.end()); + } + + // get proxy credentials from credential manager + PCREDENTIALW cred; + if (!::CredReadW(formattedProxyServer.c_str(), CRED_TYPE_GENERIC, 0, &cred)) { + return; + } + wchar_t* proxyUser = cred->UserName; + wchar_t* proxyPass = (wchar_t*)cred->CredentialBlob; + + // set proxy credentials + if (proxyUser) { + ::InternetSetOptionW(m_hWinInetRequest, + INTERNET_OPTION_PROXY_USERNAME, + static_cast(proxyUser), + static_cast(wcslen(proxyUser) + 1)); + } + + if (proxyPass) { + ::InternetSetOptionW(m_hWinInetRequest, + INTERNET_OPTION_PROXY_PASSWORD, + static_cast(proxyPass), + static_cast(wcslen(proxyPass) + 1)); + } + + ::CredFree(cred); + } + // Asynchronously send HTTP request and invoke response callback. // Ownership semantics: send(...) method self-destroys *this* upon // receiving WinInet callback. There must be absolutely no methods @@ -209,7 +311,7 @@ class WinInetRequestWrapper PCSTR szAcceptTypes[] = {"*/*", NULL}; m_hWinInetRequest = ::HttpOpenRequestA( m_hWinInetSession, m_request->m_method.c_str(), path, NULL, NULL, szAcceptTypes, - INTERNET_FLAG_KEEP_CONNECTION | INTERNET_FLAG_NO_AUTH | INTERNET_FLAG_NO_CACHE_WRITE | + INTERNET_FLAG_KEEP_CONNECTION | INTERNET_FLAG_NO_CACHE_WRITE | INTERNET_FLAG_NO_COOKIES | INTERNET_FLAG_NO_UI | INTERNET_FLAG_PRAGMA_NOCACHE | INTERNET_FLAG_RELOAD | (urlc.nScheme == INTERNET_SCHEME_HTTPS ? INTERNET_FLAG_SECURE : 0), reinterpret_cast(this)); @@ -252,6 +354,10 @@ class WinInetRequestWrapper return; } + if (ProxyAuthRequired()) { + SetProxyCredentials(); + } + // Try to send headers and request body to server DispatchEvent(OnSending); void *data = static_cast(m_request->m_body.data()); diff --git a/lib/http/HttpResponseDecoder.cpp b/lib/http/HttpResponseDecoder.cpp index 5330fdf9f..449c2bdcb 100644 --- a/lib/http/HttpResponseDecoder.cpp +++ b/lib/http/HttpResponseDecoder.cpp @@ -5,8 +5,8 @@ #include "HttpResponseDecoder.hpp" #include "ILogManager.hpp" -#include #include "utils/Utils.hpp" +#include #include #include @@ -14,10 +14,9 @@ #include "json.hpp" #endif -namespace MAT_NS_BEGIN { - - HttpResponseDecoder::HttpResponseDecoder(ITelemetrySystem& system) - : +namespace MAT_NS_BEGIN +{ + HttpResponseDecoder::HttpResponseDecoder(ITelemetrySystem& system) : m_system(system) { } @@ -46,17 +45,18 @@ namespace MAT_NS_BEGIN { #endif IHttpResponse const& response = *(ctx->httpResponse); - IHttpRequest & request = *(ctx->httpRequest); + IHttpRequest& request = *(ctx->httpRequest); HttpRequestResult outcome = Abort; auto result = response.GetResult(); - switch (result) { + switch (result) + { case HttpResult_OK: if (response.GetStatusCode() == 200) { outcome = Accepted; } - else if (response.GetStatusCode() >= 500 || response.GetStatusCode() == 408 || response.GetStatusCode() == 429) + else if (response.GetStatusCode() >= 500 || response.GetStatusCode() == 408 || response.GetStatusCode() == 429 || response.GetStatusCode() == 407) { outcome = RetryServer; } @@ -83,15 +83,17 @@ namespace MAT_NS_BEGIN { processBody(response, outcome); } - switch (outcome) { - case Accepted: { + switch (outcome) + { + case Accepted: + { LOG_INFO("HTTP request %s finished after %d ms, events were successfully uploaded to the server", - response.GetId().c_str(), ctx->durationMs); + response.GetId().c_str(), ctx->durationMs); { DebugEvent evt; evt.type = DebugEventType::EVT_HTTP_OK; evt.param1 = response.GetStatusCode(); - evt.data = static_cast(request.GetBody().data()); + evt.data = static_cast(request.GetBody().data()); evt.size = request.GetBody().size(); DispatchEvent(evt); } @@ -99,9 +101,10 @@ namespace MAT_NS_BEGIN { break; } - case Rejected: { + case Rejected: + { LOG_ERROR("HTTP request %s failed after %d ms, events were rejected by the server (%u) and will be all dropped", - response.GetId().c_str(), ctx->durationMs, response.GetStatusCode()); + response.GetId().c_str(), ctx->durationMs, response.GetStatusCode()); std::string body(reinterpret_cast(response.GetBody().data()), std::min(response.GetBody().size(), 100)); LOG_TRACE("Server response: %s%s", body.c_str(), (response.GetBody().size() > body.size()) ? "..." : ""); { @@ -112,7 +115,7 @@ namespace MAT_NS_BEGIN { // This is to be addressed with ETW trace API that can send // a detailed error context to ETW provider. evt.param1 = response.GetStatusCode(); - evt.data = static_cast(request.GetBody().data()); + evt.data = static_cast(request.GetBody().data()); evt.size = request.GetBody().size(); DispatchEvent(evt); eventsRejected(ctx); @@ -120,13 +123,14 @@ namespace MAT_NS_BEGIN { break; } - case Abort: { + case Abort: + { LOG_WARN("HTTP request %s failed after %d ms, upload was aborted and events will be sent at a different time", - response.GetId().c_str(), ctx->durationMs); + response.GetId().c_str(), ctx->durationMs); { DebugEvent evt; evt.type = DebugEventType::EVT_HTTP_FAILURE; - evt.param1 = 0; // response.GetStatusCode(); + evt.param1 = 0; // response.GetStatusCode(); DispatchEvent(evt); } ctx->httpResponse = nullptr; @@ -135,9 +139,10 @@ namespace MAT_NS_BEGIN { break; } - case RetryServer: { + case RetryServer: + { LOG_WARN("HTTP request %s failed after %d ms, a temporary server error has occurred (%u) and events will be sent at a different time", - response.GetId().c_str(), ctx->durationMs, response.GetStatusCode()); + response.GetId().c_str(), ctx->durationMs, response.GetStatusCode()); std::string body(reinterpret_cast(response.GetBody().data()), std::min(response.GetBody().size(), 100)); LOG_TRACE("Server response: %s%s", body.c_str(), (response.GetBody().size() > body.size()) ? "..." : ""); { @@ -150,9 +155,10 @@ namespace MAT_NS_BEGIN { break; } - case RetryNetwork: { + case RetryNetwork: + { LOG_WARN("HTTP request %s failed after %d ms, a network error has occurred and events will be sent at a different time", - response.GetId().c_str(), ctx->durationMs); + response.GetId().c_str(), ctx->durationMs); { DebugEvent evt; evt.type = DebugEventType::EVT_HTTP_FAILURE; @@ -165,7 +171,7 @@ namespace MAT_NS_BEGIN { } } - void HttpResponseDecoder::processBody(IHttpResponse const& response, HttpRequestResult & result) + void HttpResponseDecoder::processBody(IHttpResponse const& response, HttpRequestResult& result) { #ifdef HAVE_MAT_JSONHPP // TODO: [MG] - parse HTTP response without json.hpp library @@ -227,7 +233,8 @@ namespace MAT_NS_BEGIN { if (result != Rejected) { LOG_TRACE("HTTP response: accepted=%d rejected=%d", accepted, rejected); - } else + } + else { LOG_TRACE("HTTP response: all rejected"); } @@ -242,5 +249,5 @@ namespace MAT_NS_BEGIN { #endif } -} MAT_NS_END - +} +MAT_NS_END diff --git a/lib/include/public/IECSClient.hpp b/lib/include/public/IECSClient.hpp index 227844e3d..6864d1e62 100644 --- a/lib/include/public/IECSClient.hpp +++ b/lib/include/public/IECSClient.hpp @@ -9,6 +9,7 @@ #include "ILogger.hpp" #include +#include #include #include @@ -51,6 +52,10 @@ namespace Microsoft { // [optional] enabled ECS telemetry bool enableECSClientTelemetry = false; + + // [optional] Mandatory agents list. If not present in response, the payload + // is determined as bad + std::unordered_set mandatoryAgents; }; ///