Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,25 @@ struct ShutdownProtobuf {

namespace onnxruntime {

// Helper function to check if a data type is supported by input output nodes ofNvTensorRTRTX EP
static bool IsSupportedInputOutputDataType(ONNXTensorElementDataType data_type) {
switch (data_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: // kFLOAT - 32-bit floating point
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: // kHALF - IEEE 16-bit floating-point
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // kBF16 - Brain float 16
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: // kBOOL - 8-bit boolean
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: // kINT4 - 4-bit signed integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: // kINT8 - 8-bit signed integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // kUINT8 - 8-bit unsigned integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: // kINT32 - 32-bit signed integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: // kINT64 - 64-bit signed integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: // kFP8 - 8-bit floating point
return true;
default:
return false;
}
}

// Helper function to check if a data type is supported by NvTensorRTRTX EP
static bool IsSupportedDataType(ONNXTensorElementDataType data_type) {
switch (data_type) {
Expand All @@ -98,6 +117,7 @@ static bool IsSupportedDataType(ONNXTensorElementDataType data_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: // kINT32 - 32-bit signed integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: // kINT64 - 64-bit signed integer
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: // kFP8 - 8-bit floating point
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: // kDOUBLE - 64-bit floating point
return true;
default:
return false;
Expand Down Expand Up @@ -1939,6 +1959,28 @@ NvExecutionProvider::GetCapability(const GraphViewer& graph,
#endif
model_path_[sizeof(model_path_) - 1] = '\0';

// Early return if the model has unsupported input/output data types
for (const auto* input : graph.GetInputs()) {
const auto* tp = input->TypeAsProto();
if (tp && tp->has_tensor_type()) {
auto data_type = static_cast<ONNXTensorElementDataType>(tp->tensor_type().elem_type());
if (!IsSupportedInputOutputDataType(data_type)) {
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unsupported data type " << GetDataTypeName(data_type) << " for input node: " << input->Name();
return result;
}
}
}
for (const auto* output : graph.GetOutputs()) {
const auto* tp = output->TypeAsProto();
if (tp && tp->has_tensor_type()) {
auto data_type = static_cast<ONNXTensorElementDataType>(tp->tensor_type().elem_type());
if (!IsSupportedInputOutputDataType(data_type)) {
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unsupported data type " << GetDataTypeName(data_type) << " for output node: " << output->Name();
return result;
}
}
}

const int number_of_ort_nodes = graph.NumberOfNodes();
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/);

Expand Down
21 changes: 21 additions & 0 deletions onnxruntime/test/perftest/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,27 @@ Ort::Status CompileEpContextModel(const Ort::Env& env, const perftest::Performan
std::unordered_map<std::string, std::string> provider_options;
session_options.AppendExecutionProvider(provider_name, provider_options);

// free dim override
if (!test_config.run_config.free_dim_name_overrides.empty()) {
for (auto const& dim_override : test_config.run_config.free_dim_name_overrides) {
if (g_ort->AddFreeDimensionOverrideByName(session_options, ToUTF8String(dim_override.first).c_str(), dim_override.second) != nullptr) {
fprintf(stderr, "AddFreeDimensionOverrideByName failed for named dimension: %s\n", ToUTF8String(dim_override.first).c_str());
} else {
fprintf(stdout, "Overriding dimension with name, %s, to %d\n", ToUTF8String(dim_override.first).c_str(), (int)dim_override.second);
}
}
}

if (!test_config.run_config.free_dim_denotation_overrides.empty()) {
for (auto const& dim_override : test_config.run_config.free_dim_denotation_overrides) {
if (g_ort->AddFreeDimensionOverride(session_options, ToUTF8String(dim_override.first).c_str(), dim_override.second) != nullptr) {
fprintf(stderr, "AddFreeDimensionOverride failed for dimension denotation: %s\n", ToUTF8String(dim_override.first).c_str());
} else {
fprintf(stdout, "Overriding dimension with denotation, %s, to %d\n", ToUTF8String(dim_override.first).c_str(), (int)dim_override.second);
}
}
}

Ort::ModelCompilationOptions model_compile_options(env, session_options);
model_compile_options.SetEpContextEmbedMode(test_config.run_config.compile_binary_embed);
model_compile_options.SetInputModelPath(test_config.model_info.model_file_path.c_str());
Expand Down
Loading