Skip to content

Commit

Permalink
Use get_js_infer_result helper
Browse files Browse the repository at this point in the history
  • Loading branch information
Retribution98 committed Jan 29, 2025
1 parent ccaa98e commit ccfeeb9
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 42 deletions.
2 changes: 2 additions & 0 deletions src/bindings/js/node/include/infer_request.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,6 @@ class InferRequestWrap : public Napi::ObjectWrap<InferRequestWrap> {

void FinalizerCallback(Napi::Env env, void* finalizeData, TsfnContext* context);

std::map<std::string, ov::Tensor> get_js_infer_result(ov::InferRequest* infer_request);

void performInferenceThread(TsfnContext* context);
51 changes: 24 additions & 27 deletions src/bindings/js/node/src/infer_request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,21 +138,33 @@ Napi::Value InferRequestWrap::get_output_tensor(const Napi::CallbackInfo& info)
return TensorWrap::wrap(info.Env(), tensor);
}

Napi::Value InferRequestWrap::get_output_tensors(const Napi::CallbackInfo& info) {
auto model_outputs = _infer_request.get_compiled_model().outputs();
auto outputs_obj = Napi::Object::New(info.Env());

std::map<std::string, ov::Tensor> get_js_infer_result(ov::InferRequest* infer_request) {
auto model_outputs = infer_request->get_compiled_model().outputs();
std::map<std::string, ov::Tensor> outputs;
for (auto& output : model_outputs) {
auto tensor = _infer_request.get_tensor(output);
const auto& tensor = infer_request->get_tensor(output);
auto new_tensor = ov::Tensor(tensor.get_element_type(), tensor.get_shape());
tensor.copy_to(new_tensor);
std::string name;
if (output.get_names().empty()) {
name = output.get_node()->get_name();
} else {
name = output.get_any_name();
const auto name = output.get_names().empty() ? output.get_node()->get_name() : output.get_any_name();

auto key = name;
int counter = 1;
while (outputs.find(key) != outputs.end()) {
key = name + std::to_string(counter);
++counter;
}
outputs_obj.Set(name, TensorWrap::wrap(info.Env(), new_tensor));

outputs.insert({key, new_tensor});
}
return outputs;
}

Napi::Value InferRequestWrap::get_output_tensors(const Napi::CallbackInfo& info) {
auto output_map = get_js_infer_result(&_infer_request);
auto outputs_obj = Napi::Object::New(info.Env());

for (const auto& [key, tensor] : output_map) {
outputs_obj.Set(key, TensorWrap::wrap(info.Env(), tensor));
}
return outputs_obj;
}
Expand Down Expand Up @@ -215,22 +227,7 @@ void performInferenceThread(TsfnContext* context) {
context->_ir->infer();

auto model_outputs = context->_ir->get_compiled_model().outputs();
std::map<std::string, ov::Tensor> outputs;

for (auto& output : model_outputs) {
const auto& tensor = context->_ir->get_tensor(output);
auto new_tensor = ov::Tensor(tensor.get_element_type(), tensor.get_shape());
tensor.copy_to(new_tensor);
std::string name;
if (output.get_names().empty()) {
name = output.get_node()->get_name();
} else {
name = output.get_any_name();
}
outputs.insert({name, new_tensor});
}

context->result = outputs;
context->result = get_js_infer_result(context->_ir);
}

auto callback = [](Napi::Env env, Napi::Function, TsfnContext* context) {
Expand Down
17 changes: 2 additions & 15 deletions src/bindings/js/node/src/node_output.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,7 @@ Napi::Value Output<ov::Node>::get_partial_shape(const Napi::CallbackInfo& info)
}

Napi::Value Output<ov::Node>::get_any_name(const Napi::CallbackInfo& info) {
std::string name;
if (_output.get_names().empty()) {
name = _output.get_node()->get_name();
} else {
name = _output.get_any_name();
}

return Napi::String::New(info.Env(), name);
return Napi::String::New(info.Env(), _output.get_node()->get_name());
}

Output<const ov::Node>::Output(const Napi::CallbackInfo& info)
Expand Down Expand Up @@ -95,11 +88,5 @@ Napi::Value Output<const ov::Node>::get_partial_shape(const Napi::CallbackInfo&
}

Napi::Value Output<const ov::Node>::get_any_name(const Napi::CallbackInfo& info) {
std::string name;
if (_output.get_names().empty()) {
name = _output.get_node()->get_name();
} else {
name = _output.get_any_name();
}
return Napi::String::New(info.Env(), name);
return Napi::String::New(info.Env(), _output.get_node()->get_name());
}

0 comments on commit ccfeeb9

Please sign in to comment.