64 changes: 43 additions & 21 deletions onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -1387,9 +1387,25 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
}

// Find inputs and outputs of the subgraph

std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::IndexedSubGraph::Create();
std::unordered_map<const NodeArg*, int> original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
std::unordered_map<const NodeArg*, int> original_inputs;

// These maps store the inputs and outputs of the subgraph.
// Note that the contents of these maps are updated dynamically during node iteration
// to determine the final inputs and outputs of the subgraph.
std::unordered_map<const NodeArg*, int> fused_inputs, fused_outputs;

// This map stores node outputs that will be consumed by another node outside of this subgraph,
// so they should be put into the subgraph's output list.
std::unordered_map<const NodeArg*, int> fused_outputs_to_add;

// This map stores node outputs that are also outputs of the original graph,
// so they should be put into the subgraph's output list.
std::unordered_map<const NodeArg*, int> graph_outputs_to_add;

std::unordered_set<const NodeArg*> erased;

int input_order = 0;
int output_order = 0;

@@ -1408,7 +1424,7 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
fused_inputs.insert({input, input_order++});
}
}

@@ -1423,7 +1439,7 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
fused_inputs.insert({input, input_order++});
}
}

@@ -1451,32 +1467,38 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
// This output is also the graph's output,
// so it should be put into the subgraph's output list.
graph_outputs_to_add.insert({output, output_order});
}
fused_outputs[output] = output_order++;
fused_outputs.insert({output, output_order++});
}
} else {
fused_outputs_to_add[output] = output_order++;
// This output will be consumed by another node outside of this subgraph,
// so it should be put into the subgraph's output list.
fused_outputs_to_add.insert({output, output_order++});
}
}
} else {
for (const auto& output : node->OutputDefs()) {
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
}
// Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
else if (erased.find(output) == erased.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
}

if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
fused_outputs[output] = output_order++;
}
for (const auto& output : node->OutputDefs()) {
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
// Only when the output is neither in the input list nor in the erased list,
// and it is consumed by another node, add it to the output list
fused_outputs.insert({output, output_order++});
}
}

if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
// This output is also the graph's output,
// so it should be put into the subgraph's output list.
graph_outputs_to_add.insert({output, output_order++});
}
}
}

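Aside: to make the bookkeeping above concrete, here is a minimal, self-contained C++ sketch of the same classification idea on toy string-named values. The types and names here are hypothetical illustrations, not ORT code, and the real implementation additionally handles graph outputs and external consumers, as the hunks above show.

// Toy sketch of the subgraph I/O classification logic above (not ORT code).
// Values produced and consumed entirely inside the fused node set are erased;
// everything else becomes a subgraph input or output, numbered in first-seen order.
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct ToyNode {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

int main() {
  // Two fused nodes: A consumes "x" and produces "a_out"; B consumes "a_out"
  // and produces "b_out". Only "x" and "b_out" should cross the boundary.
  std::vector<ToyNode> fused_nodes = {{{"x"}, {"a_out"}}, {{"a_out"}, {"b_out"}}};

  std::unordered_map<std::string, int> fused_inputs, fused_outputs;
  std::unordered_set<std::string> erased;
  int input_order = 0, output_order = 0;

  for (const auto& node : fused_nodes) {
    for (const auto& in : node.inputs) {
      auto it = fused_outputs.find(in);
      if (it != fused_outputs.end()) {
        fused_outputs.erase(it);  // produced and consumed inside the subgraph
        erased.insert(in);
      } else if (erased.count(in) == 0) {
        fused_inputs.insert({in, input_order++});
      }
    }
    for (const auto& out : node.outputs) {
      auto it = fused_inputs.find(out);
      if (it != fused_inputs.end()) {
        fused_inputs.erase(it);
        erased.insert(out);
      } else if (erased.count(out) == 0) {
        fused_outputs.insert({out, output_order++});
      }
    }
  }

  for (const auto& [name, order] : fused_inputs)
    std::cout << "subgraph input:  " << name << " (order " << order << ")\n";
  for (const auto& [name, order] : fused_outputs)
    std::cout << "subgraph output: " << name << " (order " << order << ")\n";
  // Prints "x" as the only input and "b_out" as the only output;
  // "a_out" never escapes the fused set, so it is erased.
}
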
64 changes: 43 additions & 21 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -2035,9 +2035,25 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
}

// Find inputs and outputs of the subgraph

std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::IndexedSubGraph::Create();
std::unordered_map<const NodeArg*, int> original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
std::unordered_map<const NodeArg*, int> original_inputs;

// These maps store the inputs and outputs of the subgraph.
// Note that the contents of these maps are updated dynamically during node iteration
// to determine the final inputs and outputs of the subgraph.
std::unordered_map<const NodeArg*, int> fused_inputs, fused_outputs;

// This map stores node outputs that will be consumed by another node outside of this subgraph,
// so they should be put into the subgraph's output list.
std::unordered_map<const NodeArg*, int> fused_outputs_to_add;

// This map stores node outputs that are also outputs of the original graph,
// so they should be put into the subgraph's output list.
std::unordered_map<const NodeArg*, int> graph_outputs_to_add;

std::unordered_set<const NodeArg*> erased;

int input_order = 0;
int output_order = 0;

@@ -2056,7 +2072,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
fused_inputs.insert({input, input_order++});
}
}

@@ -2071,7 +2087,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
erased.insert(input);
} else if (erased.find(input) == erased.end()) {
// Only when input is neither in output list nor erased list, add the input to input list
fused_inputs[input] = input_order++;
fused_inputs.insert({input, input_order++});
}
}

@@ -2099,32 +2115,38 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
// This output is also the graph's output,
// so it should be put into the subgraph's output list.
graph_outputs_to_add.insert({output, output_order});
}
fused_outputs[output] = output_order++;
fused_outputs.insert({output, output_order++});
}
} else {
fused_outputs_to_add[output] = output_order++;
// This output will be consumed by another node outside of this subgraph,
// so it should be put into the subgraph's output list.
fused_outputs_to_add.insert({output, output_order++});
}
}
} else {
for (const auto& output : node->OutputDefs()) {
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
}
// Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
else if (erased.find(output) == erased.end()) {
if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
graph_outputs_to_add[output] = output_order;
}
}

if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
fused_outputs[output] = output_order++;
}
for (const auto& output : node->OutputDefs()) {
const auto& it = fused_inputs.find(output);
if (it != fused_inputs.end()) {
fused_inputs.erase(it);
erased.insert(output);
} else if (erased.find(output) == erased.end()) {
if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
// Only when the output is neither in the input list nor in the erased list,
// and it is consumed by another node, add it to the output list
fused_outputs.insert({output, output_order++});
}
}

if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
// This output is also the graph's output,
// so it should be put into the subgraph's output list.
graph_outputs_to_add.insert({output, output_order++});
}
}
}

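Aside on the recurring operator[] -> insert() change in both providers: on a duplicate key, fused_inputs[input] = input_order++ overwrites the order recorded when the key was first seen, whereas insert keeps the first entry (the counter advances either way). Below is a tiny standalone demo of that difference; attributing the change to this motivation is my inference, as the diff itself does not state it.

// Demonstrates the std::unordered_map update semantics behind the
// operator[] -> insert() change above.
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, int> by_bracket, by_insert;

  int order = 0;
  by_bracket["out"] = order++;  // records {"out", 0}
  by_bracket["out"] = order++;  // overwrites the mapped value: {"out", 1}

  order = 0;
  by_insert.insert({"out", order++});  // records {"out", 0}
  by_insert.insert({"out", order++});  // duplicate key: no-op, value stays 0
                                       // (order++ still runs, so the counter advances)

  std::cout << by_bracket["out"] << " vs " << by_insert["out"] << "\n";  // prints "1 vs 0"
}
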
50 changes: 50 additions & 0 deletions onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
@@ -203,6 +203,56 @@ TEST_P(TypeTests, IOTypes) {
}
}

TEST(NvExecutionProviderTest, TestSessionOutputs) {
/*
* Model #1:
*
* "input" ---> TopK ---
*                     |---> "scores"
*                     |--- Less ---> "Less_output_0"
*                     |--- Div ---> "Div_output_0"
*                     |--- Mod ---> "labels"
*/
{
Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
Ort::SessionOptions session_options;
session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});

auto model_path = ORT_TSTR("model_with_topk_and_multiple_graph_outputs.onnx");
Ort::Status status(CreateModelWithTopKWhichContainsGraphOutput(model_path));
ASSERT_TRUE(status.IsOK());

Ort::Session session(env, model_path, session_options);

size_t output_count = session.GetOutputCount();
ASSERT_TRUE(output_count == 4);
}

/*
* Model #2:
*
* "X" ---> Dropout ---> MatMul ---> "Y"
*              |           ^
*              |           |
*              |          "W"
*              |
*              ----> Can't be graph's output
*/
{
Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
Ort::SessionOptions session_options;
session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});

auto model_path = ORT_TSTR("model_with_node_output_not_used.onnx");
Ort::Status status(CreateModelWithNodeOutputNotUsed(model_path));
ASSERT_TRUE(status.IsOK());

Ort::Session session(env, model_path, session_options);

size_t output_count = session.GetOutputCount();
ASSERT_TRUE(output_count == 1);
}
}
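
For readers reproducing these checks outside gtest, here is a minimal sketch that lists a session's output names with the public ONNX Runtime C++ API. It assumes the model file generated by the first test case above already exists on disk; error handling is omitted.

// Sketch: enumerate the graph outputs of a loaded session via the C++ API.
#include <iostream>
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "inspect"};
  Ort::SessionOptions session_options;
  // Assumes the model created by CreateModelWithTopKWhichContainsGraphOutput exists.
  Ort::Session session(env, ORT_TSTR("model_with_topk_and_multiple_graph_outputs.onnx"),
                       session_options);

  Ort::AllocatorWithDefaultOptions allocator;
  for (size_t i = 0; i < session.GetOutputCount(); ++i) {
    auto name = session.GetOutputNameAllocated(i, allocator);
    std::cout << "output " << i << ": " << name.get() << "\n";
  }
  // For Model #1 above this should print four names:
  // "scores", "Less_output_0", "Div_output_0", "labels".
}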

INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests,
::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,