Skip to content

Commit

Permalink
feat: introduce predict output parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
sileht committed Jan 28, 2021
1 parent 1e178c9 commit 7973b92
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 64 deletions.
71 changes: 21 additions & 50 deletions src/backends/ncnn/ncnnlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,63 +186,33 @@ namespace dd
}

APIData ad_output = ad.getobj("parameters").getobj("output");

// Get bbox
bool bbox = false;
if (ad_output.has("bbox"))
bbox = ad_output.get("bbox").get<bool>();

// Ctc model
bool ctc = false;
int blank_label = -1;
if (ad_output.has("ctc"))
{
ctc = ad_output.get("ctc").get<bool>();
if (ctc)
{
if (ad_output.has("blank_label"))
blank_label = ad_output.get("blank_label").get<int>();
}
}
auto output_params
= ad_output.createSharedDTO<PredictOutputParametersDto>();

// Extract detection or classification
int ret = 0;
std::string out_blob;
if (_init_dto->outputBlob != nullptr)
out_blob = _init_dto->outputBlob->std_str();

if (out_blob.empty())
{
if (bbox == true)
if (output_params->bbox == true)
out_blob = "detection_out";
else if (ctc == true)
else if (output_params->ctc == true)
out_blob = "probs";
else if (_timeserie)
out_blob = "rnn_pred";
else
out_blob = "prob";
}

std::vector<APIData> vrad;

// Get confidence_threshold
float confidence_threshold = 0.0;
if (ad_output.has("confidence_threshold"))
{
apitools::get_float(ad_output, "confidence_threshold",
confidence_threshold);
}

// Get best
int best = -1;
if (ad_output.has("best"))
{
best = ad_output.get("best").get<int>();
}
if (best == -1 || best > _init_dto->nclasses)
best = _init_dto->nclasses;
if (output_params->best == -1 || output_params->best > _init_dto->nclasses)
output_params->best = _init_dto->nclasses;

std::vector<APIData> vrad;

// for loop around batch size
// for loop around batch size
#pragma omp parallel for num_threads(*_init_dto->threads)
for (size_t b = 0; b < inputc._ids.size(); b++)
{
Expand All @@ -256,13 +226,13 @@ namespace dd
ex.set_num_threads(_init_dto->threads);
ex.input(_init_dto->inputBlob->c_str(), inputc._in.at(b));

ret = ex.extract(out_blob.c_str(), inputc._out.at(b));
int ret = ex.extract(out_blob.c_str(), inputc._out.at(b));
if (ret == -1)
{
throw MLLibInternalException("NCNN internal error");
}

if (bbox == true)
if (output_params->bbox == true)
{
std::string uri = inputc._ids.at(b);
auto bit = inputc._imgs_size.find(uri);
Expand All @@ -282,7 +252,7 @@ namespace dd
for (int i = 0; i < inputc._out.at(b).h; i++)
{
const float *values = inputc._out.at(b).row(i);
if (values[1] < confidence_threshold)
if (values[1] < output_params->confidence_threshold)
break; // output is sorted by confidence

cats.push_back(this->_mlmodel.get_hcorresp(values[0]));
Expand All @@ -300,7 +270,7 @@ namespace dd
bboxes.push_back(ad_bbox);
}
}
else if (ctc == true)
else if (output_params->ctc == true)
{
int alphabet = inputc._out.at(b).w;
int time_step = inputc._out.at(b).h;
Expand All @@ -313,11 +283,11 @@ namespace dd
}

std::vector<int> pred_label_seq;
int prev = blank_label;
int prev = output_params->blank_label;
for (int t = 0; t < time_step; ++t)
{
int cur = pred_label_seq_with_blank[t];
if (cur != prev && cur != blank_label)
if (cur != prev && cur != output_params->blank_label)
pred_label_seq.push_back(cur);
prev = cur;
}
Expand Down Expand Up @@ -365,12 +335,13 @@ namespace dd
vec[i] = std::make_pair(cls_scores[i], i);
}

std::partial_sort(vec.begin(), vec.begin() + best, vec.end(),
std::partial_sort(vec.begin(), vec.begin() + output_params->best,
vec.end(),
std::greater<std::pair<float, int>>());

for (int i = 0; i < best; i++)
for (int i = 0; i < output_params->best; i++)
{
if (vec[i].first < confidence_threshold)
if (vec[i].first < output_params->confidence_threshold)
continue;
cats.push_back(this->_mlmodel.get_hcorresp(vec[i].second));
probs.push_back(vec[i].first);
Expand All @@ -380,7 +351,7 @@ namespace dd
rad.add("uri", inputc._ids.at(b));
rad.add("loss", 0.0);
rad.add("cats", cats);
if (bbox == true)
if (output_params->bbox == true)
rad.add("bboxes", bboxes);
if (_timeserie)
{
Expand All @@ -402,7 +373,7 @@ namespace dd
tout.add_results(vrad);
int nclasses = this->_init_dto->nclasses;
out.add("nclasses", nclasses);
if (bbox == true)
if (output_params->bbox == true)
out.add("bbox", true);
out.add("roi", false);
out.add("multibox_rois", false);
Expand Down
56 changes: 56 additions & 0 deletions src/http/dto/predict.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/**
* DeepDetect
* Copyright (c) 2021 Jolibrain SASU
* Author: Mehdi Abaakouk <[email protected]>
*
* This file is part of deepdetect.
*
* deepdetect is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* deepdetect is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef HTTP_DTO_PREDICT_H
#define HTTP_DTO_PREDICT_H
#include "dd_config.h"
#include "oatpp/core/Types.hpp"
#include "oatpp/core/macro/codegen.hpp"

#include OATPP_CODEGEN_BEGIN(DTO) ///< Begin DTO codegen section

class PredictOutputParametersDto : public oatpp::DTO
{
DTO_INIT(PredictOutputParametersDto, DTO /* extends */)

/* ncnn */
DTO_FIELD(Boolean, bbox) = false;
DTO_FIELD(Boolean, ctc) = false;
DTO_FIELD(Int32, blank_label) = -1;
DTO_FIELD(Float32, confidence_threshold) = 0.0;

/* ncnn && supervised init && supervised predict */
DTO_FIELD(Int32, best) = -1;

/* output supervised init */
DTO_FIELD(Boolean, nclasses) = false; // Looks like a bug ?

/* output supervised predict */
DTO_FIELD(Boolean, index) = false;
DTO_FIELD(Boolean, build_index) = false;
DTO_FIELD(Boolean, search) = false;
DTO_FIELD(Int32, search_nn);
DTO_FIELD(Int32, nprobe);
};

#include OATPP_CODEGEN_END(DTO) ///< End DTO codegen section

#endif
34 changes: 20 additions & 14 deletions src/supervisedoutputconnector.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#define SUPERVISEDOUTPUTCONNECTOR_H
#define TS_METRICS_EPSILON 1E-2

#include "http/dto/predict.hpp"

template <typename T>
bool SortScorePairDescend(const std::pair<double, T> &pair1,
const std::pair<double, T> &pair2)
Expand Down Expand Up @@ -161,10 +163,11 @@ namespace dd
void init(const APIData &ad)
{
APIData ad_out = ad.getobj("parameters").getobj("output");
if (ad_out.has("best"))
_best = ad_out.get("best").get<int>();
auto output_params
= ad_out.createSharedDTO<PredictOutputParametersDto>();
_best = output_params->best;
if (_best == -1)
_best = ad_out.get("nclasses").get<int>();
_best = output_params->nclasses;
}

/**
Expand Down Expand Up @@ -242,13 +245,13 @@ namespace dd
* @param ad_out output data object
* @param bcats supervised output connector
*/
void best_cats(const APIData &ad_out, SupervisedOutput &bcats,
void best_cats(SupervisedOutput &bcats, const int &output_param_best,
const int &nclasses, const bool &has_bbox,
const bool &has_roi, const bool &has_mask) const
{
int best = _best;
if (ad_out.has("best"))
best = ad_out.get("best").get<int>();
if (output_param_best != -1)
best = output_param_best;
if (best == -1)
best = nclasses;
if (!has_bbox && !has_roi && !has_mask)
Expand Down Expand Up @@ -399,6 +402,8 @@ namespace dd
*/
void finalize(const APIData &ad_in, APIData &ad_out, MLModel *mlm)
{
auto output_params = ad_in.createSharedDTO<PredictOutputParametersDto>();

#ifndef USE_SIMSEARCH
(void)mlm;
#endif
Expand Down Expand Up @@ -443,12 +448,13 @@ namespace dd
}

if (!timeseries)
best_cats(ad_in, bcats, nclasses, has_bbox, has_roi, has_mask);
best_cats(bcats, output_params->best, nclasses, has_bbox, has_roi,
has_mask);

std::unordered_set<std::string> indexed_uris;
#ifdef USE_SIMSEARCH
// index
if (ad_in.has("index") && ad_in.get("index").get<bool>())
if (output_params->index)
{
// check whether index has been created
if (!mlm->_se)
Expand Down Expand Up @@ -553,7 +559,7 @@ namespace dd
}

// build index
if (ad_in.has("build_index") && ad_in.get("build_index").get<bool>())
if (output_params->build_index)
{
if (mlm->_se)
mlm->build_index();
Expand All @@ -562,7 +568,7 @@ namespace dd
}

// search
if (ad_in.has("search") && ad_in.get("search").get<bool>())
if (output_params->search)
{
// check whether index has been created
if (!mlm->_se)
Expand All @@ -582,11 +588,11 @@ namespace dd
int search_nn = _best;
if (has_roi)
search_nn = _search_nn;
if (ad_in.has("search_nn"))
search_nn = ad_in.get("search_nn").get<int>();
if (output_params->search_nn)
search_nn = output_params->search_nn;
#ifdef USE_FAISS
if (ad_in.has("nprobe"))
mlm->_se->_tse->_nprobe = ad_in.get("nprobe").get<int>();
if (output_params->nprobe)
mlm->_se->_tse->_nprobe = output_params->nprobe;
#endif
if (!has_roi)
{
Expand Down

0 comments on commit 7973b92

Please sign in to comment.