diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
index 62210d65848d1..3c6f6c9b82c83 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -1387,9 +1387,25 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
   }
 
   // Find inputs and outputs of the subgraph
+
   std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::IndexedSubGraph::Create();
-  std::unordered_map<const NodeArg*, int> original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
+  std::unordered_map<const NodeArg*, int> original_inputs;
+
+  // These maps store the inputs and outputs of the subgraph.
+  // Note that their contents are updated dynamically during node iteration
+  // to determine the final inputs and outputs of the subgraph.
+  std::unordered_map<const NodeArg*, int> fused_inputs, fused_outputs;
+
+  // This map stores node outputs that are consumed by nodes outside of this subgraph.
+  // Such outputs must be added to the subgraph's output list.
+  std::unordered_map<const NodeArg*, int> fused_outputs_to_add;
+
+  // This map stores node outputs that are also outputs of the original graph.
+  // Such outputs must be added to the subgraph's output list.
+  std::unordered_map<const NodeArg*, int> graph_outputs_to_add;
+
   std::unordered_set<const NodeArg*> erased;
+
   int input_order = 0;
   int output_order = 0;
 
@@ -1408,7 +1424,7 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
         erased.insert(input);
       } else if (erased.find(input) == erased.end()) {
         // Only when input is neither in output list nor erased list, add the input to input list
-        fused_inputs[input] = input_order++;
+        fused_inputs.insert({input, input_order++});
       }
     }
 
@@ -1423,7 +1439,7 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
         erased.insert(input);
       } else if (erased.find(input) == erased.end()) {
         // Only when input is neither in output list nor erased list, add the input to input list
-        fused_inputs[input] = input_order++;
+        fused_inputs.insert({input, input_order++});
       }
     }
 
@@ -1451,32 +1467,38 @@ std::unique_ptr<IndexedSubGraph> NvExecutionProvider::GetSubGraph(SubGraph_t gra
           erased.insert(output);
         } else if (erased.find(output) == erased.end()) {
           if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
-            graph_outputs_to_add[output] = output_order;
+            // This output is also an output of the graph,
+            // so it must be added to the subgraph's output list.
+            graph_outputs_to_add.insert({output, output_order});
           }
-          fused_outputs[output] = output_order++;
+          fused_outputs.insert({output, output_order++});
         }
       } else {
-        fused_outputs_to_add[output] = output_order++;
+        // This output is consumed by a node outside of this subgraph,
+        // so it must be added to the subgraph's output list.
+        fused_outputs_to_add.insert({output, output_order++});
       }
     }
-    } else {
-      for (const auto& output : node->OutputDefs()) {
-        const auto& it = fused_inputs.find(output);
-        if (it != fused_inputs.end()) {
-          fused_inputs.erase(it);
-          erased.insert(output);
-        }
-        // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
-        else if (erased.find(output) == erased.end()) {
-          if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
-            graph_outputs_to_add[output] = output_order;
-          }
+    }
 
-          if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
-            fused_outputs[output] = output_order++;
-          }
+    for (const auto& output : node->OutputDefs()) {
+      const auto& it = fused_inputs.find(output);
+      if (it != fused_inputs.end()) {
+        fused_inputs.erase(it);
+        erased.insert(output);
+      } else if (erased.find(output) == erased.end()) {
+        if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
+          // Only when the output is neither in the input list nor in the erased list,
+          // and it is consumed by another node, add it to the output list.
+          fused_outputs.insert({output, output_order++});
         }
       }
+
+      if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
+        // This output is also an output of the graph,
+        // so it must be added to the subgraph's output list.
+        graph_outputs_to_add.insert({output, output_order++});
+      }
     }
   }
 
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index cd0c0e4bffdb5..932253ff1b634 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -2035,9 +2035,25 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
   }
 
   // Find inputs and outputs of the subgraph
+
   std::unique_ptr<IndexedSubGraph> sub_graph = onnxruntime::IndexedSubGraph::Create();
-  std::unordered_map<const NodeArg*, int> original_inputs, fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add;
+  std::unordered_map<const NodeArg*, int> original_inputs;
+
+  // These maps store the inputs and outputs of the subgraph.
+  // Note that their contents are updated dynamically during node iteration
+  // to determine the final inputs and outputs of the subgraph.
+  std::unordered_map<const NodeArg*, int> fused_inputs, fused_outputs;
+
+  // This map stores node outputs that are consumed by nodes outside of this subgraph.
+  // Such outputs must be added to the subgraph's output list.
+  std::unordered_map<const NodeArg*, int> fused_outputs_to_add;
+
+  // This map stores node outputs that are also outputs of the original graph.
+  // Such outputs must be added to the subgraph's output list.
+  std::unordered_map<const NodeArg*, int> graph_outputs_to_add;
+
   std::unordered_set<const NodeArg*> erased;
+
   int input_order = 0;
   int output_order = 0;
 
@@ -2056,7 +2072,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
         erased.insert(input);
       } else if (erased.find(input) == erased.end()) {
         // Only when input is neither in output list nor erased list, add the input to input list
-        fused_inputs[input] = input_order++;
+        fused_inputs.insert({input, input_order++});
       }
     }
 
@@ -2071,7 +2087,7 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
        erased.insert(input);
       } else if (erased.find(input) == erased.end()) {
         // Only when input is neither in output list nor erased list, add the input to input list
-        fused_inputs[input] = input_order++;
+        fused_inputs.insert({input, input_order++});
       }
     }
 
@@ -2099,32 +2115,38 @@ std::unique_ptr<IndexedSubGraph> TensorrtExecutionProvider::GetSubGraph(SubGraph
           erased.insert(output);
         } else if (erased.find(output) == erased.end()) {
           if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
-            graph_outputs_to_add[output] = output_order;
+            // This output is also an output of the graph,
+            // so it must be added to the subgraph's output list.
+            graph_outputs_to_add.insert({output, output_order});
           }
-          fused_outputs[output] = output_order++;
+          fused_outputs.insert({output, output_order++});
         }
       } else {
-        fused_outputs_to_add[output] = output_order++;
+        // This output is consumed by a node outside of this subgraph,
+        // so it must be added to the subgraph's output list.
+        fused_outputs_to_add.insert({output, output_order++});
       }
     }
-    } else {
-      for (const auto& output : node->OutputDefs()) {
-        const auto& it = fused_inputs.find(output);
-        if (it != fused_inputs.end()) {
-          fused_inputs.erase(it);
-          erased.insert(output);
-        }
-        // Only when output is neither in input list nor erased list, and the output is consumed by another node, add the output to output list
-        else if (erased.find(output) == erased.end()) {
-          if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
-            graph_outputs_to_add[output] = output_order;
-          }
+    }
 
-          if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
-            fused_outputs[output] = output_order++;
-          }
+    for (const auto& output : node->OutputDefs()) {
+      const auto& it = fused_inputs.find(output);
+      if (it != fused_inputs.end()) {
+        fused_inputs.erase(it);
+        erased.insert(output);
+      } else if (erased.find(output) == erased.end()) {
+        if (graph.GetGraph().GetConsumerNodes(output->Name()).size() > 0) {
+          // Only when the output is neither in the input list nor in the erased list,
+          // and it is consumed by another node, add it to the output list.
+          fused_outputs.insert({output, output_order++});
        }
      }
+
+      if (graph_output_names.find(output->Name()) != graph_output_names.end()) {
+        // This output is also an output of the graph,
+        // so it must be added to the subgraph's output list.
+        graph_outputs_to_add.insert({output, output_order++});
+      }
     }
   }
 
diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
index d8cc56d738175..72b275c14ea3b 100644
--- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
+++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc
@@ -203,6 +203,58 @@ TEST_P(TypeTests, IOTypes) {
   }
 }
 
+TEST(NvExecutionProviderTest, TestSessionOutputs) {
+  /*
+   * Model #1:
+   *
+   * "input" ---> TopK ---
+   *                     |---> "scores"
+   *                     |--- Less ---> "Less_output_0"
+   *                     |--- Div ---> "Div_17_output_0"
+   *                     |--- Mod ---> "labels"
+   */
+  {
+    Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
+    Ort::SessionOptions session_options;
+    session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});
+
+    auto model_path = ORT_TSTR("model_with_topk_and_multiple_graph_outputs.onnx");
+    Ort::Status status(CreateModelWithTopKWhichContainsGraphOutput(model_path));
+    ASSERT_TRUE(status.IsOK());
+
+    Ort::Session session(env, model_path, session_options);
+
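+    // All four graph outputs must be reported by the session, including
+    // "scores", which is produced by TopK but consumed by no other node.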
+    size_t output_count = session.GetOutputCount();
+    ASSERT_TRUE(output_count == 4);
+  }
+
+  /*
+   * Model #2:
+   *
+   * "X" ---> Dropout ---> MatMul ---> "Y"
+   *             |            ^
+   *             |            |
+   *             |           "W"
+   *             ----> "dropout_mask": can't be a graph output
+   */
+  {
+    Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
+    Ort::SessionOptions session_options;
+    session_options.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {});
+
+    auto model_path = ORT_TSTR("model_with_node_output_not_used.onnx");
+    Ort::Status status(CreateModelWithNodeOutputNotUsed(model_path));
+    ASSERT_TRUE(status.IsOK());
+
+    Ort::Session session(env, model_path, session_options);
+
+    size_t output_count = session.GetOutputCount();
+    ASSERT_TRUE(output_count == 1);
+  }
+}
+
 INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests,
                          ::testing::Values(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
                                            ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc
index 3a91fc1ba09bb..79cc2fef42666 100644
--- a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc
+++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc
@@ -19,6 +19,14 @@
 #include "test/common/trt_op_test_utils.h"
 #include "test/providers/provider_test_utils.h"
 #include "test/unittest_util/framework_test_utils.h"
+#include <onnx/checker.h>
+#include <onnx/onnx_pb.h>
+#include <fstream>
+#include <string>
+#include <vector>
+
+#define ONNX_IR_VERSION 11
+#define OPSET_VERSION 23
 
 namespace onnxruntime {
 namespace test {
@@ -464,5 +472,270 @@ Ort::IoBinding generate_io_binding(
   return binding;
 }
 
+// Helper: make ONNX_NAMESPACE::TensorProto
+ONNX_NAMESPACE::TensorProto MakeTensor(const std::string& name,
+                                       ONNX_NAMESPACE::TensorProto::DataType dtype,
+                                       const std::vector<int64_t>& dims,
+                                       const std::vector<int64_t>& vals) {
+  ONNX_NAMESPACE::TensorProto t;
+  t.set_name(name);
+  t.set_data_type(dtype);
+  for (auto d : dims) t.add_dims(d);
+  for (auto v : vals) t.add_int64_data(v);
+  return t;
+}
+
+ONNX_NAMESPACE::TensorProto MakeTensorFloat(const std::string& name,
+                                            const std::vector<int64_t>& dims,
+                                            const std::vector<float>& vals) {
+  ONNX_NAMESPACE::TensorProto t;
+  t.set_name(name);
+  t.set_data_type(ONNX_NAMESPACE::TensorProto::FLOAT);
+  for (auto d : dims) t.add_dims(d);
+  for (auto v : vals) t.add_float_data(v);
+  return t;
+}
+
+OrtStatus* CreateModelWithNodeOutputNotUsed(const PathString& model_name) {
+  // --------------------
+  // Create Model
+  // --------------------
+  ONNX_NAMESPACE::ModelProto model;
+  model.set_ir_version(ONNX_IR_VERSION);
+  auto* opset = model.add_opset_import();
+  opset->set_domain("");  // empty = default ONNX domain
+  opset->set_version(OPSET_VERSION);
+
+  ONNX_NAMESPACE::GraphProto* graph = model.mutable_graph();
+  graph->set_name("DropoutMatMulGraph");
+
+  // --------------------
+  // Create Inputs
+  //   X: [3, 2]
+  //   W: [2, 3]
+  // --------------------
+  {
+    auto* x = graph->add_input();
+    x->set_name("X");
+
+    auto* type = x->mutable_type();
+    auto* tt = type->mutable_tensor_type();
+    tt->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
+
+    auto* shape = tt->mutable_shape();
+    shape->add_dim()->set_dim_value(3);
+    shape->add_dim()->set_dim_value(2);
+  }
+
+  {
+    auto* w = graph->add_input();
+    w->set_name("W");
+
+    auto* type = w->mutable_type();
+    auto* tt = type->mutable_tensor_type();
+    tt->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
+
+    auto* shape = tt->mutable_shape();
+    shape->add_dim()->set_dim_value(2);
+    shape->add_dim()->set_dim_value(3);
+  }
+
+  // --------------------
+  // Output Y: [3, 3]  ([3, 2] x [2, 3])
+  // --------------------
+  {
+    auto* y = graph->add_output();
+    y->set_name("Y");
+
+    auto* type = y->mutable_type();
+    auto* tt = type->mutable_tensor_type();
+    tt->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
+
+    auto* shape = tt->mutable_shape();
+    shape->add_dim()->set_dim_value(3);
+    shape->add_dim()->set_dim_value(3);
+  }
+
+  // --------------------
+  // Dropout Node
+  // --------------------
+  {
+    ONNX_NAMESPACE::NodeProto* node = graph->add_node();
+    node->set_name("DropoutNode");
+    node->set_op_type("Dropout");
+
+    node->add_input("X");
+    node->add_output("dropout_out");
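+    // "dropout_mask" is deliberately produced but never consumed, and it is
+    // not a graph output either; this exercises the unused-output handling.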
+    node->add_output("dropout_mask");
+  }
+
+  // --------------------
+  // MatMul Node
+  // --------------------
+  {
+    ONNX_NAMESPACE::NodeProto* node = graph->add_node();
+    node->set_name("MatMulNode");
+    node->set_op_type("MatMul");
+
+    node->add_input("dropout_out");
+    node->add_input("W");
+    node->add_output("Y");
+  }
+
+  // --------------------
+  // Validate
+  // --------------------
+  try {
+    onnx::checker::check_model(model);
+  } catch (const std::exception& ex) {
+    std::string error_msg = "Model validation failed: " + std::string(ex.what());
+    const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    return ort_api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, error_msg.c_str());
+  }
+
+  std::ofstream ofs(model_name, std::ios::binary);
+  if (!model.SerializeToOstream(&ofs)) {
+    const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    return ort_api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "Failed to write model");
+  }
+
+  return nullptr;
+}
+
+OrtStatus* CreateModelWithTopKWhichContainsGraphOutput(const PathString& model_name) {
+  ONNX_NAMESPACE::ModelProto model;
+  model.set_ir_version(ONNX_IR_VERSION);
+  auto* opset = model.add_opset_import();
+  opset->set_domain("");  // empty = default ONNX domain
+  opset->set_version(OPSET_VERSION);
+
+  auto* graph = model.mutable_graph();
+  graph->set_name("TopKGraph");
+
+  // ======================
+  // ---- Model Input ----
+  // ======================
+  {
+    auto* inp = graph->add_input();
+    inp->set_name("input");
+
+    auto* type = inp->mutable_type();
+    auto* tt = type->mutable_tensor_type();
+    tt->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
+
+    // Shape: ["N"]
+    auto* shape = tt->mutable_shape();
+    shape->add_dim()->set_dim_param("N");
+  }
+
+  // ======================
+  // ---- Initializers ----
+  // ======================
+  {
+    // K = [300]
+    ONNX_NAMESPACE::TensorProto K = MakeTensor("K", ONNX_NAMESPACE::TensorProto::INT64, {1}, {300});
+    *graph->add_initializer() = K;
+
+    // zero = 0 (scalar)
+    ONNX_NAMESPACE::TensorProto zero = MakeTensor("zero", ONNX_NAMESPACE::TensorProto::INT64, {}, {0});
+    *graph->add_initializer() = zero;
+
+    // twenty_six = 26 (scalar)
+    ONNX_NAMESPACE::TensorProto ts = MakeTensor("twenty_six", ONNX_NAMESPACE::TensorProto::INT64, {}, {26});
+    *graph->add_initializer() = ts;
+  }
+
+  // ======================
+  // ---- TopK ----
+  // ======================
+  {
+    ONNX_NAMESPACE::NodeProto* n = graph->add_node();
+    n->set_op_type("TopK");
+    n->add_input("input");
+    n->add_input("K");
+    n->add_output("scores");
+    n->add_output("topk_indices");
+    n->set_name("TopK");
+  }
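+  // "scores" goes straight to a graph output with no consumer node, while
+  // "topk_indices" is consumed by Less/Div/Mod but is not a graph output.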
+
+  // ======================
+  // ---- Less ----
+  // ======================
+  {
+    ONNX_NAMESPACE::NodeProto* n = graph->add_node();
+    n->set_op_type("Less");
+    n->add_input("topk_indices");
+    n->add_input("zero");
+    n->add_output("Less_output_0");
+    n->set_name("Less");
+  }
+
+  // ======================
+  // ---- Div ----
+  // ======================
+  {
+    ONNX_NAMESPACE::NodeProto* n = graph->add_node();
+    n->set_op_type("Div");
+    n->add_input("topk_indices");
+    n->add_input("twenty_six");
+    n->add_output("Div_17_output_0");
+    n->set_name("Div");
+  }
+
+  // ======================
+  // ---- Mod ----
+  // ======================
+  {
+    ONNX_NAMESPACE::NodeProto* n = graph->add_node();
+    n->set_op_type("Mod");
+    n->add_input("topk_indices");
+    n->add_input("twenty_six");
+    n->add_output("labels");
+    n->set_name("Mod");
+  }
+
+  // =========================
+  // ---- Graph Outputs ----
+  // =========================
+  auto add_output = [&](const std::string& name, ONNX_NAMESPACE::TensorProto::DataType type, const std::string& dim) {
+    auto* out = graph->add_output();
+    out->set_name(name);
+
+    auto* tt = out->mutable_type()->mutable_tensor_type();
+    tt->set_elem_type(type);
+
+    auto* shape = tt->mutable_shape();
+    shape->add_dim()->set_dim_param(dim);
+  };
+
+  add_output("scores", ONNX_NAMESPACE::TensorProto::FLOAT, "K");
+  add_output("Less_output_0", ONNX_NAMESPACE::TensorProto::BOOL, "K");
+  add_output("Div_17_output_0", ONNX_NAMESPACE::TensorProto::INT64, "K");
+  add_output("labels", ONNX_NAMESPACE::TensorProto::INT64, "K");
+
+  // ======================
+  // Validate + Save
+  // ======================
+  try {
+    onnx::checker::check_model(model);
+  } catch (const std::exception& e) {
+    std::string error_msg = "Model validation failed: " + std::string(e.what());
+    const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    return ort_api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, error_msg.c_str());
+  }
+
+  std::ofstream ofs(model_name, std::ios::binary);
+  if (!model.SerializeToOstream(&ofs)) {
+    const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    return ort_api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "Failed to write model");
+  }
+
+  return nullptr;
+}
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h
index 0f011af8211ca..11000a76ad2cd 100644
--- a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h
+++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h
@@ -119,6 +119,9 @@ void CreateBaseModel(const PathString& model_name,
 void CreateLargeLLMModel(const PathString& model_path,
                          const PathString& external_data_path);
 
+OrtStatus* CreateModelWithTopKWhichContainsGraphOutput(const PathString& model_name);
+OrtStatus* CreateModelWithNodeOutputNotUsed(const PathString& model_name);
+
 Ort::IoBinding generate_io_binding(
     Ort::Session& session,
     std::map<std::string, std::vector<int64_t>> shape_overwrites = {},
diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
index a746493d779f8..a0149613c21c3 100644
--- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
+++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -1,5 +1,7 @@
-// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+#include "onnxruntime_cxx_api.h"
+#include "tensorrt_test_utils.h"
 #include "core/graph/onnx_protobuf.h"
 #include "core/session/inference_session.h"
 #include "test/providers/provider_test_utils.h"
@@ -1358,5 +1360,59 @@ TEST(TensorrtExecutionProviderTest, RemoveCycleTest) {
   ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches));
   VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
 }
+
+TEST(TensorrtExecutionProviderTest, TestSessionOutputs) {
+  /*
+   * Model #1:
+   *
+   * "input" ---> TopK ---
+   *                     |---> "scores"
+   *                     |--- Less ---> "Less_output_0"
+   *                     |--- Div ---> "Div_17_output_0"
+   *                     |--- Mod ---> "labels"
+   */
+  {
+    Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
+    OrtTensorRTProviderOptionsV2 provider_options;
+    Ort::SessionOptions session_options;
+    session_options.AppendExecutionProvider_TensorRT_V2(provider_options);
+
+    auto model_path = ORT_TSTR("model_with_topk_and_multiple_graph_outputs.onnx");
+    Ort::Status status(CreateModelWithTopKWhichContainsGraphOutput(model_path));
+    ASSERT_TRUE(status.IsOK());
+
+    Ort::Session session(env, model_path, session_options);
+
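+    // All four graph outputs must be reported by the session, including
+    // "scores", which is produced by TopK but consumed by no other node.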
+    size_t output_count = session.GetOutputCount();
+    ASSERT_TRUE(output_count == 4);
+  }
+
+  /*
+   * Model #2:
+   *
+   * "X" ---> Dropout ---> MatMul ---> "Y"
+   *             |            ^
+   *             |            |
+   *             |           "W"
+   *             ----> "dropout_mask": can't be a graph output
+   */
+  {
+    Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "test"};
+    OrtTensorRTProviderOptionsV2 provider_options;
+    Ort::SessionOptions session_options;
+    session_options.AppendExecutionProvider_TensorRT_V2(provider_options);
+
+    auto model_path = ORT_TSTR("model_with_node_output_not_used.onnx");
+    Ort::Status status(CreateModelWithNodeOutputNotUsed(model_path));
+    ASSERT_TRUE(status.IsOK());
+
+    Ort::Session session(env, model_path, session_options);
+
+    size_t output_count = session.GetOutputCount();
+    ASSERT_TRUE(output_count == 1);
+  }
+}
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_test_utils.cc b/onnxruntime/test/providers/tensorrt/tensorrt_test_utils.cc
new file mode 100644
index 0000000000000..51738342c6cb0
--- /dev/null
+++ b/onnxruntime/test/providers/tensorrt/tensorrt_test_utils.cc
@@ -0,0 +1,287 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
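+// Model-building helpers for the TensorRT EP session-output tests; these
+// mirror the helpers in test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc.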
+
+#include "onnxruntime_cxx_api.h"
+#include "core/common/path_string.h"
+
+#include <onnx/checker.h>
+#include <onnx/onnx_pb.h>
+
+#include <fstream>
+#include <string>
+#include <vector>
+
+#define ONNX_IR_VERSION 11
+#define OPSET_VERSION 23
+
+namespace onnxruntime {
+namespace test {
+
+// Helper: make ONNX_NAMESPACE::TensorProto
+ONNX_NAMESPACE::TensorProto MakeTensor(const std::string& name,
+                                       ONNX_NAMESPACE::TensorProto::DataType dtype,
+                                       const std::vector<int64_t>& dims,
+                                       const std::vector<int64_t>& vals) {
+  ONNX_NAMESPACE::TensorProto t;
+  t.set_name(name);
+  t.set_data_type(dtype);
+  for (auto d : dims) t.add_dims(d);
+  for (auto v : vals) t.add_int64_data(v);
+  return t;
+}
+
+ONNX_NAMESPACE::TensorProto MakeTensorFloat(const std::string& name,
+                                            const std::vector<int64_t>& dims,
+                                            const std::vector<float>& vals) {
+  ONNX_NAMESPACE::TensorProto t;
+  t.set_name(name);
+  t.set_data_type(ONNX_NAMESPACE::TensorProto::FLOAT);
+  for (auto d : dims) t.add_dims(d);
+  for (auto v : vals) t.add_float_data(v);
+  return t;
+}
+
+OrtStatus* CreateModelWithNodeOutputNotUsed(const PathString& model_name) {
+  // --------------------
+  // Create Model
+  // --------------------
+  ONNX_NAMESPACE::ModelProto model;
+  model.set_ir_version(ONNX_IR_VERSION);
+  auto* opset = model.add_opset_import();
+  opset->set_domain("");  // empty = default ONNX domain
+  opset->set_version(OPSET_VERSION);
+
+  ONNX_NAMESPACE::GraphProto* graph = model.mutable_graph();
+  graph->set_name("DropoutMatMulGraph");
+
+  // --------------------
+  // Create Inputs
+  //   X: [3, 2]
+  //   W: [2, 3]
+  // --------------------
+  {
+    auto* x = graph->add_input();
+    x->set_name("X");
+
+    auto* type = x->mutable_type();
+    auto* tt = type->mutable_tensor_type();
+    tt->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
+
+    auto* shape = tt->mutable_shape();
+    shape->add_dim()->set_dim_value(3);
+    shape->add_dim()->set_dim_value(2);
+  }
+
+  {
+    auto* w = graph->add_input();
+    w->set_name("W");
+
+    auto* type = w->mutable_type();
+    auto* tt = type->mutable_tensor_type();
+    tt->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
+
+    auto* shape = tt->mutable_shape();
+    shape->add_dim()->set_dim_value(2);
+    shape->add_dim()->set_dim_value(3);
+  }
+
+  // --------------------
+  // Output Y: [3, 3]  ([3, 2] x [2, 3])
+  // --------------------
+  {
+    auto* y = graph->add_output();
+    y->set_name("Y");
+
+    auto* type = y->mutable_type();
+    auto* tt = type->mutable_tensor_type();
+    tt->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
+
+    auto* shape = tt->mutable_shape();
+    shape->add_dim()->set_dim_value(3);
+    shape->add_dim()->set_dim_value(3);
+  }
+
+  // --------------------
+  // Dropout Node
+  // --------------------
+  {
+    ONNX_NAMESPACE::NodeProto* node = graph->add_node();
+    node->set_name("DropoutNode");
+    node->set_op_type("Dropout");
+
+    node->add_input("X");
+    node->add_output("dropout_out");
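+    // "dropout_mask" is deliberately produced but never consumed, and it is
+    // not a graph output either; this exercises the unused-output handling.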
+    node->add_output("dropout_mask");
+  }
+
+  // --------------------
+  // MatMul Node
+  // --------------------
+  {
+    ONNX_NAMESPACE::NodeProto* node = graph->add_node();
+    node->set_name("MatMulNode");
+    node->set_op_type("MatMul");
+
+    node->add_input("dropout_out");
+    node->add_input("W");
+    node->add_output("Y");
+  }
+
+  // --------------------
+  // Validate
+  // --------------------
+  try {
+    onnx::checker::check_model(model);
+  } catch (const std::exception& ex) {
+    std::string error_msg = "Model validation failed: " + std::string(ex.what());
+    const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    return ort_api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, error_msg.c_str());
+  }
+
+  std::ofstream ofs(model_name, std::ios::binary);
+  if (!model.SerializeToOstream(&ofs)) {
+    const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    return ort_api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "Failed to write model");
+  }
+
+  return nullptr;
+}
+
+OrtStatus* CreateModelWithTopKWhichContainsGraphOutput(const PathString& model_name) {
+  ONNX_NAMESPACE::ModelProto model;
+  model.set_ir_version(ONNX_IR_VERSION);
+  auto* opset = model.add_opset_import();
+  opset->set_domain("");  // empty = default ONNX domain
+  opset->set_version(OPSET_VERSION);
+
+  auto* graph = model.mutable_graph();
+  graph->set_name("TopKGraph");
+
+  // ======================
+  // ---- Model Input ----
+  // ======================
+  {
+    auto* inp = graph->add_input();
+    inp->set_name("input");
+
+    auto* type = inp->mutable_type();
+    auto* tt = type->mutable_tensor_type();
+    tt->set_elem_type(ONNX_NAMESPACE::TensorProto::FLOAT);
+
+    // Shape: ["N"]
+    auto* shape = tt->mutable_shape();
+    shape->add_dim()->set_dim_param("N");
+  }
+
+  // ======================
+  // ---- Initializers ----
+  // ======================
+  {
+    // K = [300]
+    ONNX_NAMESPACE::TensorProto K = MakeTensor("K", ONNX_NAMESPACE::TensorProto::INT64, {1}, {300});
+    *graph->add_initializer() = K;
+
+    // zero = 0 (scalar)
+    ONNX_NAMESPACE::TensorProto zero = MakeTensor("zero", ONNX_NAMESPACE::TensorProto::INT64, {}, {0});
+    *graph->add_initializer() = zero;
+
+    // twenty_six = 26 (scalar)
+    ONNX_NAMESPACE::TensorProto ts = MakeTensor("twenty_six", ONNX_NAMESPACE::TensorProto::INT64, {}, {26});
+    *graph->add_initializer() = ts;
+  }
+
+  // ======================
+  // ---- TopK ----
+  // ======================
+  {
+    ONNX_NAMESPACE::NodeProto* n = graph->add_node();
+    n->set_op_type("TopK");
+    n->add_input("input");
+    n->add_input("K");
+    n->add_output("scores");
+    n->add_output("topk_indices");
+    n->set_name("TopK");
+  }
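+  // "scores" goes straight to a graph output with no consumer node, while
+  // "topk_indices" is consumed by Less/Div/Mod but is not a graph output.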
+
+  // ======================
+  // ---- Less ----
+  // ======================
+  {
+    ONNX_NAMESPACE::NodeProto* n = graph->add_node();
+    n->set_op_type("Less");
+    n->add_input("topk_indices");
+    n->add_input("zero");
+    n->add_output("Less_output_0");
+    n->set_name("Less");
+  }
+
+  // ======================
+  // ---- Div ----
+  // ======================
+  {
+    ONNX_NAMESPACE::NodeProto* n = graph->add_node();
+    n->set_op_type("Div");
+    n->add_input("topk_indices");
+    n->add_input("twenty_six");
+    n->add_output("Div_17_output_0");
+    n->set_name("Div");
+  }
+
+  // ======================
+  // ---- Mod ----
+  // ======================
+  {
+    ONNX_NAMESPACE::NodeProto* n = graph->add_node();
+    n->set_op_type("Mod");
+    n->add_input("topk_indices");
+    n->add_input("twenty_six");
+    n->add_output("labels");
+    n->set_name("Mod");
+  }
+
+  // =========================
+  // ---- Graph Outputs ----
+  // =========================
+  auto add_output = [&](const std::string& name, ONNX_NAMESPACE::TensorProto::DataType type, const std::string& dim) {
+    auto* out = graph->add_output();
+    out->set_name(name);
+
+    auto* tt = out->mutable_type()->mutable_tensor_type();
+    tt->set_elem_type(type);
+
+    auto* shape = tt->mutable_shape();
+    shape->add_dim()->set_dim_param(dim);
+  };
+
+  add_output("scores", ONNX_NAMESPACE::TensorProto::FLOAT, "K");
+  add_output("Less_output_0", ONNX_NAMESPACE::TensorProto::BOOL, "K");
+  add_output("Div_17_output_0", ONNX_NAMESPACE::TensorProto::INT64, "K");
+  add_output("labels", ONNX_NAMESPACE::TensorProto::INT64, "K");
+
+  // ======================
+  // Validate + Save
+  // ======================
+  try {
+    onnx::checker::check_model(model);
+  } catch (const std::exception& e) {
+    std::string error_msg = "Model validation failed: " + std::string(e.what());
+    const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    return ort_api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, error_msg.c_str());
+  }
+
+  std::ofstream ofs(model_name, std::ios::binary);
+  if (!model.SerializeToOstream(&ofs)) {
+    const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    return ort_api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, "Failed to write model");
+  }
+
+  return nullptr;
+}
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_test_utils.h b/onnxruntime/test/providers/tensorrt/tensorrt_test_utils.h
new file mode 100644
index 0000000000000..3857d8ee80131
--- /dev/null
+++ b/onnxruntime/test/providers/tensorrt/tensorrt_test_utils.h
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+#include "onnxruntime_cxx_api.h"
+#include "core/common/path_string.h"
+
+namespace onnxruntime {
+namespace test {
+OrtStatus* CreateModelWithTopKWhichContainsGraphOutput(const PathString& model_name);
+OrtStatus* CreateModelWithNodeOutputNotUsed(const PathString& model_name);
+}  // namespace test
+}  // namespace onnxruntime