|
9 | 9 | #include "utility/Logging.hpp"
|
10 | 10 |
|
11 | 11 | #ifdef DEPTHAI_ENABLE_CURL
|
12 |
| - #include <cpr/cpr.h> |
| 12 | + #include <cpr/api.h> |
| 13 | + #include <cpr/parameters.h> |
| 14 | + #include <cpr/status_codes.h> |
13 | 15 | namespace dai {
|
14 | 16 | class ZooManager {
|
15 | 17 | public:
|
@@ -139,10 +141,10 @@ bool checkIsErrorHub(const cpr::Response& response) {
|
139 | 141 | // Check if response is an HTTP error
|
140 | 142 | if(response.status_code != cpr::status::HTTP_OK) return true;
|
141 | 143 |
|
142 |
| - // If there was no HTTP error, check response content for errors |
| 144 | + // If there was no HTTP error, check presence of required fields |
143 | 145 | nlohmann::json responseJson = nlohmann::json::parse(response.text);
|
144 |
| - if(responseJson.contains("errors")) return true; |
145 |
| - if(responseJson["data"]["ml"]["modelDownloads"].is_null()) return true; |
| 146 | + if(!responseJson.contains("hash")) return true; |
| 147 | + if(!responseJson.contains("download_links")) return true; |
146 | 148 |
|
147 | 149 | // All checks passed - no errors yay
|
148 | 150 | return false;
|
@@ -206,47 +208,53 @@ bool ZooManager::isModelCached() const {
|
206 | 208 | }
|
207 | 209 |
|
208 | 210 | void ZooManager::downloadModel() {
|
209 |
| - // graphql query to send to Hub - always the same |
210 |
| - constexpr std::string_view MODEL_ZOO_QUERY = "query MlDownloads($input: MlModelDownloadsInput!) {ml { modelDownloads(input : $input) { data }}}"; |
211 |
| - |
212 |
| - // Setup request body |
213 |
| - nlohmann::json requestBody; |
214 |
| - requestBody["query"] = MODEL_ZOO_QUERY; |
215 |
| - |
216 |
| - // Add REQUIRED parameters |
217 |
| - requestBody["variables"]["input"]["platform"] = modelDescription.platform; |
218 |
| - requestBody["variables"]["input"]["slug"] = modelDescription.model; |
219 |
| - |
220 |
| - // Add OPTIONAL parameters |
221 |
| - if(!modelDescription.optimizationLevel.empty()) { |
222 |
| - requestBody["variables"]["input"]["optimizationLevel"] = modelDescription.optimizationLevel; |
223 |
| - } |
224 |
| - if(!modelDescription.compressionLevel.empty()) { |
225 |
| - requestBody["variables"]["input"]["compressionLevel"] = modelDescription.compressionLevel; |
226 |
| - } |
227 |
| - if(!modelDescription.snpeVersion.empty()) { |
228 |
| - requestBody["variables"]["input"]["snpeVersion"] = modelDescription.snpeVersion; |
| 211 | + // Add request parameters |
| 212 | + cpr::Parameters params; |
| 213 | + |
| 214 | + // Required parameters |
| 215 | + // clang-format off |
| 216 | + std::vector<std::pair<std::string, std::string>> requiredParams = { |
| 217 | + {"slug", modelDescription.model}, |
| 218 | + {"platform", modelDescription.platform} |
| 219 | + }; |
| 220 | + // clang-format on |
| 221 | + for(const auto& param : requiredParams) { |
| 222 | + params.Add({param.first, param.second}); |
229 | 223 | }
|
230 |
| - if(!modelDescription.modelPrecisionType.empty()) { |
231 |
| - requestBody["variables"]["input"]["modelPrecisionType"] = modelDescription.modelPrecisionType; |
| 224 | + |
| 225 | + // Optional parameters |
| 226 | + // clang-format off |
| 227 | + std::vector<std::pair<std::string, std::string>> optionalParams = { |
| 228 | + {"optimizationLevel", modelDescription.optimizationLevel}, |
| 229 | + {"compressionLevel", modelDescription.compressionLevel}, |
| 230 | + {"snpeVersion", modelDescription.snpeVersion}, |
| 231 | + {"modelPrecisionType", modelDescription.modelPrecisionType} |
| 232 | + }; |
| 233 | + // clang-format on |
| 234 | + for(const auto& param : optionalParams) { |
| 235 | + if(!param.second.empty()) { |
| 236 | + params.Add({param.first, param.second}); |
| 237 | + } |
232 | 238 | }
|
| 239 | + |
233 | 240 | // Set the Authorization headers
|
234 | 241 | cpr::Header headers = {
|
235 | 242 | {"Content-Type", "application/json"},
|
236 | 243 | };
|
237 | 244 | if(!apiKey.empty()) {
|
238 | 245 | headers["Authorization"] = "Bearer " + apiKey;
|
239 | 246 | }
|
240 |
| - // Send HTTP request to Hub |
241 |
| - cpr::Response response = cpr::Post(cpr::Url{MODEL_ZOO_URL}, headers, cpr::Body{requestBody.dump()}); |
| 247 | + |
| 248 | + // Send HTTP GET request to REST endpoint |
| 249 | + cpr::Response response = cpr::Get(cpr::Url{MODEL_ZOO_URL}, headers, params); |
242 | 250 | if(checkIsErrorHub(response)) {
|
243 | 251 | removeModelCacheFolder();
|
244 | 252 | throw std::runtime_error(generateErrorMessageHub(response));
|
245 | 253 | }
|
246 | 254 |
|
247 |
| - // Extract download link from response |
| 255 | + // Extract download links from response |
248 | 256 | nlohmann::json responseJson = nlohmann::json::parse(response.text);
|
249 |
| - auto downloadLinks = responseJson["data"]["ml"]["modelDownloads"]["data"].get<std::vector<std::string>>(); |
| 257 | + auto downloadLinks = responseJson["download_links"].get<std::vector<std::string>>(); |
250 | 258 |
|
251 | 259 | // Download all files and store them in cache folder
|
252 | 260 | for(const auto& downloadLink : downloadLinks) {
|
|
0 commit comments