diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index f2ba19250af3f..415d7a5756b07 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -25,6 +25,9 @@ endif() set(CMAKE_C_STANDARD 99) include(CheckCXXCompilerFlag) +if(NOT ANDROID) + include(CheckIncludeFile) +endif() include(CheckLanguage) include(CMakeDependentOption) include(FetchContent) @@ -734,6 +737,10 @@ if (WIN32) endif() else() + if(NOT ANDROID) + # On ANDROID it requires ANDROID API level >=28 + check_include_file("glob.h" HAVE_GLOB_H) + endif() check_cxx_compiler_flag(-Wambiguous-reversed-operator HAS_AMBIGUOUS_REVERSED_OPERATOR) check_cxx_compiler_flag(-Wbitwise-instead-of-logical HAS_BITWISE_INSTEAD_OF_LOGICAL) check_cxx_compiler_flag(-Wcast-function-type HAS_CAST_FUNCTION_TYPE) diff --git a/cmake/onnxruntime_config.h.in b/cmake/onnxruntime_config.h.in index f82a23bf4026b..fe038e8baac96 100644 --- a/cmake/onnxruntime_config.h.in +++ b/cmake/onnxruntime_config.h.in @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once - +#cmakedefine HAVE_GLOB_H #cmakedefine HAS_BITWISE_INSTEAD_OF_LOGICAL #cmakedefine HAS_CAST_FUNCTION_TYPE #cmakedefine HAS_CATCH_VALUE diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 789647673e782..78e13f5201c5b 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -808,7 +808,7 @@ if(NOT IOS) onnxruntime_add_include_to_target(onnx_test_runner_common onnxruntime_common onnxruntime_framework onnxruntime_test_utils onnx onnx_proto re2::re2 flatbuffers::flatbuffers Boost::mp11 safeint_interface) - add_dependencies(onnx_test_runner_common onnx_test_data_proto ${onnxruntime_EXTERNAL_DEPENDENCIES}) + add_dependencies(onnx_test_runner_common ${onnxruntime_EXTERNAL_DEPENDENCIES}) target_include_directories(onnx_test_runner_common PRIVATE ${eigen_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) @@ -906,7 +906,6 @@ AddTest( SOURCES ${all_tests} ${onnxruntime_unittest_main_src} LIBS ${onnx_test_runner_common_lib} ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs} - onnx_test_data_proto DEPENDS ${all_dependencies} TEST_ARGS ${test_all_args} ) @@ -1007,23 +1006,6 @@ endif() set(test_data_target onnxruntime_test_all) -onnxruntime_add_static_library(onnx_test_data_proto ${TEST_SRC_DIR}/proto/tml.proto) -add_dependencies(onnx_test_data_proto onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES}) -#onnx_proto target should mark this definition as public, instead of private -target_compile_definitions(onnx_test_data_proto PRIVATE "-DONNX_API=") -onnxruntime_add_include_to_target(onnx_test_data_proto onnx_proto) -if (MSVC) - # Cutlass code has an issue with the following: - # warning C4100: 'magic': unreferenced formal parameter - target_compile_options(onnx_test_data_proto PRIVATE "/wd4100") -endif() -target_include_directories(onnx_test_data_proto PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) -set_target_properties(onnx_test_data_proto PROPERTIES FOLDER "ONNXRuntimeTest") -if(NOT DEFINED onnx_SOURCE_DIR) - find_path(onnx_SOURCE_DIR NAMES "onnx/onnx-ml.proto3" "onnx/onnx-ml.proto" REQUIRED) -endif() -onnxruntime_protobuf_generate(APPEND_PATH IMPORT_DIRS ${onnx_SOURCE_DIR} TARGET onnx_test_data_proto) - # # onnxruntime_ir_graph test data # @@ -1101,7 +1083,6 @@ endif() set(onnx_test_libs onnxruntime_test_utils ${ONNXRUNTIME_TEST_LIBS} - onnx_test_data_proto ${onnxruntime_EXTERNAL_LIBRARIES}) if (NOT IOS) @@ -1201,7 +1182,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_compile_options(onnx_test_runner_common PRIVATE -D_CRT_SECURE_NO_WARNINGS) endif() - if (NOT onnxruntime_REDUCED_OPS_BUILD AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (NOT onnxruntime_REDUCED_OPS_BUILD AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnx_SOURCE_DIR) add_test(NAME onnx_test_pytorch_converted COMMAND onnx_test_runner ${onnx_SOURCE_DIR}/onnx/backend/test/data/pytorch-converted) add_test(NAME onnx_test_pytorch_operator @@ -1258,7 +1239,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) #onnxruntime_common is kind of ok because it is thin, tiny and totally stateless. set(onnxruntime_perf_test_libs onnx_test_runner_common onnxruntime_test_utils onnxruntime_common - onnxruntime onnxruntime_flatbuffers onnx_test_data_proto + onnxruntime onnxruntime_flatbuffers ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) if(NOT WIN32) @@ -1308,7 +1289,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if(WIN32) target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE debug dbghelp advapi32) endif() - target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers onnx_test_data_proto ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) + target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) set_target_properties(onnxruntime_qnn_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest") endif() diff --git a/onnxruntime/test/ir/onnx_model_test.cc b/onnxruntime/test/ir/onnx_model_test.cc index 9327d86966981..55fc4f42bec64 100644 --- a/onnxruntime/test/ir/onnx_model_test.cc +++ b/onnxruntime/test/ir/onnx_model_test.cc @@ -26,44 +26,6 @@ class ONNXModelsTest : public ::testing::Test { std::unique_ptr logger_; }; -#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS -// Tests that Resolve() properly clears the state of topological sorted nodes, -// inputs, outputs and valueInfo. -// Assumes the graph passed in has been previously resolved. -static void TestResolve(onnxruntime::Graph& graph) { - GraphViewer graph_viewer(graph); - auto& nodes_before = graph_viewer.GetNodesInTopologicalOrder(); - auto& inputs_before = graph.GetInputs(); - auto& outputs_before = graph.GetOutputs(); - auto& value_info_before = graph.GetValueInfo(); - - // Touch the graph to force Resolve() to recompute. - graph.SetGraphResolveNeeded(); - graph.SetGraphProtoSyncNeeded(); - ASSERT_STATUS_OK(graph.Resolve()); - - GraphViewer graph_viewer_2(graph); - auto& nodes_after = graph_viewer_2.GetNodesInTopologicalOrder(); - auto& inputs_after = graph.GetInputs(); - auto& outputs_after = graph.GetOutputs(); - auto& value_info_after = graph.GetValueInfo(); - - // Multiple calls to Resolve() should not alter the sorted nodes, - // inputs, outputs and valueInfo. The internal state should be - // cleared. - EXPECT_EQ(nodes_before, nodes_after); - EXPECT_EQ(inputs_before, inputs_after); - EXPECT_EQ(outputs_before, outputs_after); - EXPECT_EQ(value_info_before, value_info_after); -} - -TEST_F(ONNXModelsTest, squeeze_net) { - // NOTE: this requires the current directory to be where onnxruntime_ir_UT.exe is located - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ORT_TSTR("../models/opset8/test_squeezenet/model.onnx"), model, nullptr, *logger_)); - TestResolve(model->MainGraph()); -} -#endif TEST_F(ONNXModelsTest, non_existing_model) { // NOTE: this requires the current directory to be where onnxruntime_ir_UT.exe is located @@ -96,76 +58,6 @@ class ONNXModelsTest1 : public ::testing::TestWithParam { return oss.str(); } }; -#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS -TEST_F(ONNXModelsTest, bvlc_alexnet_1) { - using ::google::protobuf::io::CodedInputStream; - using ::google::protobuf::io::FileInputStream; - using ::google::protobuf::io::ZeroCopyInputStream; - int fd; - ASSERT_STATUS_OK(Env::Default().FileOpenRd(ORT_TSTR("../models/opset8/test_bvlc_alexnet/model.onnx"), fd)); - ASSERT_TRUE(fd > 0); - std::unique_ptr raw_input(new FileInputStream(fd)); - std::unique_ptr coded_input(new CodedInputStream(raw_input.get())); - // Allows protobuf library versions < 3.2.0 to parse messages greater than 64MB. - coded_input->SetTotalBytesLimit(INT_MAX); - ModelProto model_proto; - bool result = model_proto.ParseFromCodedStream(coded_input.get()); - coded_input.reset(); - raw_input.reset(); - EXPECT_TRUE(result); - ASSERT_STATUS_OK(Env::Default().FileClose(fd)); - - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ORT_TSTR("../models/opset8/test_bvlc_alexnet/model.onnx"), model, nullptr, - *logger_)); - - // Check the graph input/output/value_info should have the same size as specified in the model file. - EXPECT_EQ(static_cast(model_proto.graph().value_info_size()), model->MainGraph().GetValueInfo().size()); - EXPECT_EQ(static_cast(model_proto.graph().input_size()), model->MainGraph().GetInputs().size() + model->MainGraph().GetAllInitializedTensors().size()); - EXPECT_EQ(static_cast(model_proto.graph().output_size()), model->MainGraph().GetOutputs().size()); - TestResolve(model->MainGraph()); -} - -TEST_P(ONNXModelsTest1, LoadFromFile) { - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(GetModelFileName(), model, nullptr, - *logger_)); - TestResolve(model->MainGraph()); -} - -TEST_P(ONNXModelsTest1, LoadFromProtobuf) { - using ::google::protobuf::io::CodedInputStream; - using ::google::protobuf::io::FileInputStream; - using ::google::protobuf::io::ZeroCopyInputStream; - int fd; - ASSERT_STATUS_OK(Env::Default().FileOpenRd(GetModelFileName(), fd)); - ASSERT_TRUE(fd > 0); - std::unique_ptr raw_input(new FileInputStream(fd)); - std::unique_ptr coded_input(new CodedInputStream(raw_input.get())); - coded_input->SetTotalBytesLimit(INT_MAX); - ModelProto model_proto; - bool result = model_proto.ParseFromCodedStream(coded_input.get()); - coded_input.reset(); - raw_input.reset(); - ASSERT_TRUE(result); - ASSERT_STATUS_OK(Env::Default().FileClose(fd)); - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(std::move(model_proto), model, nullptr, - *logger_)); - TestResolve(model->MainGraph()); -} - -#ifndef DISABLE_CONTRIB_OPS -INSTANTIATE_TEST_SUITE_P(ONNXModelsTests, - ONNXModelsTest1, - ::testing::Values(ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_googlenet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), ORT_TSTR("densenet121"), ORT_TSTR("emotion_ferplus"), ORT_TSTR("inception_v1"), ORT_TSTR("inception_v2"), ORT_TSTR("mnist"), ORT_TSTR("resnet50"), ORT_TSTR("shufflenet"), ORT_TSTR("squeezenet"), ORT_TSTR("tiny_yolov2"), ORT_TSTR("vgg19"), ORT_TSTR("zfnet512"))); -#else -INSTANTIATE_TEST_SUITE_P(ONNXModelsTests, - ONNXModelsTest1, - ::testing::Values(ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_googlenet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("bvlc_reference_rcnn_ilsvrc13"), ORT_TSTR("densenet121"), ORT_TSTR("emotion_ferplus"), ORT_TSTR("inception_v1"), ORT_TSTR("inception_v2"), ORT_TSTR("mnist"), ORT_TSTR("resnet50"), ORT_TSTR("shufflenet"), ORT_TSTR("squeezenet"), ORT_TSTR("vgg19"), ORT_TSTR("zfnet512"))); -#endif - -#endif // test a model that conforms to ONNX IR v4 where there are initializers that are not graph inputs. // a NodeArg should be created for all initializers in this case. diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 3433a88515b53..228a31c93d36c 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -4,7 +4,7 @@ // needs to be included first to get around onnxruntime\cmake\external\onnx\onnx/common/constants.h(14): error C2513: 'bool': no variable declared before '=' #include "TestCase.h" - +#include "fnmatch_simple.h" #include #include #include @@ -26,6 +26,7 @@ #include "core/common/common.h" #include "core/platform/env.h" #include +#include "fnmatch_simple.h" #include "core/platform/path_lib.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/allocator.h" @@ -34,115 +35,35 @@ using namespace onnxruntime; using namespace onnxruntime::common; -using google::protobuf::RepeatedPtrField; static constexpr int protobuf_block_size_in_bytes = 4 * 1024 * 1024; const std::string TestModelInfo::unknown_version = "unknown version"; namespace { +using PATH_STRING_TYPE = std::basic_string; -template -inline Ort::Value CreateTensorWithDataAsOrtValue(const Ort::MemoryInfo& info, - OrtAllocator*, - const std::vector& dims, - std::vector& input) { - return Ort::Value::CreateTensor(static_cast(info), input.data(), input.size() * sizeof(T), - dims.data(), dims.size()); -} - -template <> -inline Ort::Value CreateTensorWithDataAsOrtValue(const Ort::MemoryInfo&, - OrtAllocator* allocator, - const std::vector& dims, - std::vector& input) { - auto tensor_value = Ort::Value::CreateTensor(allocator, dims.data(), dims.size(), - ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); - - std::vector p_str; - for (const auto& s : input) { - p_str.push_back(s.c_str()); - } - - tensor_value.FillStringTensor(p_str.data(), p_str.size()); - return tensor_value; -} - -template -Ort::Value PbMapToOrtValue(const google::protobuf::Map& map) { - auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - Ort::AllocatorWithDefaultOptions allocator; - const size_t ele_count = map.size(); - std::vector dims(1, static_cast(ele_count)); - std::vector keys(ele_count); - std::vector values(ele_count); - size_t i = 0; - for (auto& kvp : map) { - keys[i] = kvp.first; - values[i] = kvp.second; - ++i; - } - - //// See helper above - auto ort_keys = CreateTensorWithDataAsOrtValue(info, allocator, dims, keys); - auto ort_values = CreateTensorWithDataAsOrtValue(info, allocator, dims, values); - return Ort::Value::CreateMap(ort_keys, ort_values); -} - -template -Ort::Value VectorProtoToOrtValue(const RepeatedPtrField& input) { - auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - Ort::AllocatorWithDefaultOptions allocator; - std::vector seq; - seq.reserve(input.size()); - for (const T& v : input) { - // create key tensor - const auto& map = v.v(); - size_t ele_count = map.size(); - using key_type = typename std::remove_reference::type::key_type; - using value_type = typename std::remove_reference::type::mapped_type; - std::vector dims(1, static_cast(ele_count)); - std::vector keys(ele_count); - std::vector values(ele_count); - size_t i = 0; - for (auto& kvp : map) { - keys[i] = kvp.first; - values[i] = kvp.second; - ++i; - } - - auto ort_keys = CreateTensorWithDataAsOrtValue(info, allocator, dims, keys); - auto ort_values = CreateTensorWithDataAsOrtValue(info, allocator, dims, values); - auto ort_map = Ort::Value::CreateMap(ort_keys, ort_values); - seq.push_back(std::move(ort_map)); - } - return Ort::Value::CreateSequence(seq); -} - -template -static int ExtractFileNo(const std::basic_string& name) { +static int ExtractFileNo(const std::filesystem::path& pathstr) { + PATH_STRING_TYPE name = pathstr; size_t p1 = name.rfind('.'); size_t p2 = name.rfind('_', p1); ++p2; - std::basic_string number_str = name.substr(p2, p1 - p2); - const CHAR_T* start = number_str.c_str(); - const CHAR_T* end = number_str.c_str(); - long ret = OrtStrtol(start, const_cast(&end)); + PATH_STRING_TYPE number_str = name.substr(p2, p1 - p2); + const PATH_CHAR_TYPE* start = number_str.c_str(); + const PATH_CHAR_TYPE* end = start; + long ret = OrtStrtol(start, const_cast(&end)); if (end == start) { ORT_THROW("parse file name failed"); } return static_cast(ret); } -using PATH_STRING_TYPE = std::basic_string; -static void SortFileNames(std::vector>& input_pb_files) { +static void SortFileNames(std::vector& input_pb_files) { if (input_pb_files.size() <= 1) return; std::sort(input_pb_files.begin(), input_pb_files.end(), - [](const std::basic_string& left, const std::basic_string& right) -> bool { - std::basic_string leftname = GetLastComponent(left); - std::basic_string rightname = GetLastComponent(right); - int left1 = ExtractFileNo(leftname); - int right1 = ExtractFileNo(rightname); + [](const std::filesystem::path& left, std::filesystem::path& right) -> bool { + int left1 = ExtractFileNo(left.filename()); + int right1 = ExtractFileNo(right.filename()); return left1 < right1; }); @@ -158,122 +79,15 @@ static void SortFileNames(std::vector>& input_ } } -Ort::Value TensorToOrtValue(const ONNX_NAMESPACE::TensorProto& t, onnxruntime::test::HeapBuffer& b) { - size_t len = 0; - auto status = onnxruntime::test::GetSizeInBytesFromTensorProto<0>(t, &len); - if (!status.IsOK()) { - ORT_THROW(status.ToString()); - } - void* p = len == 0 ? nullptr : b.AllocMemory(len); - Ort::Value temp_value{nullptr}; - onnxruntime::test::OrtCallback d; - auto cpu_memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - status = onnxruntime::test::TensorProtoToMLValue(t, onnxruntime::test::MemBuffer(p, len, *static_cast(cpu_memory_info)), - temp_value, d); - if (!status.IsOK()) { - ORT_THROW(status.ToString()); - } - if (d.f) { - b.AddDeleter(d); - } - return temp_value; -} - -void LoopDataFile(int test_data_pb_fd, bool is_input, const TestModelInfo& modelinfo, - std::unordered_map& name_data_map, onnxruntime::test::HeapBuffer& b, - std::ostringstream& oss) { - google::protobuf::io::FileInputStream f(test_data_pb_fd, protobuf_block_size_in_bytes); - f.SetCloseOnDelete(true); - google::protobuf::io::CodedInputStream coded_input(&f); - bool clean_eof = false; - [[maybe_unused]] int item_id = 1; - for (proto::TraditionalMLData data; - ParseDelimitedFromCodedStream(&data, &coded_input, &clean_eof); - ++item_id, data.Clear()) { - ORT_TRY { - Ort::Value gvalue{nullptr}; - switch (data.values_case()) { - case proto::TraditionalMLData::kVectorMapStringToFloat: - gvalue = VectorProtoToOrtValue(data.vector_map_string_to_float().v()); - break; - case proto::TraditionalMLData::kVectorMapInt64ToFloat: - gvalue = VectorProtoToOrtValue(data.vector_map_int64_to_float().v()); - break; - case proto::TraditionalMLData::kMapStringToString: - gvalue = PbMapToOrtValue(data.map_string_to_string().v()); - break; - case proto::TraditionalMLData::kMapStringToInt64: - gvalue = PbMapToOrtValue(data.map_string_to_int64().v()); - break; - case proto::TraditionalMLData::kMapStringToFloat: - gvalue = PbMapToOrtValue(data.map_string_to_float().v()); - break; - case proto::TraditionalMLData::kMapStringToDouble: - gvalue = PbMapToOrtValue(data.map_string_to_double().v()); - break; - case proto::TraditionalMLData::kMapInt64ToString: - gvalue = PbMapToOrtValue(data.map_int64_to_string().v()); - break; - case proto::TraditionalMLData::kMapInt64ToInt64: - gvalue = PbMapToOrtValue(data.map_int64_to_int64().v()); - break; - case proto::TraditionalMLData::kMapInt64ToFloat: - gvalue = PbMapToOrtValue(data.map_int64_to_float().v()); - break; - case proto::TraditionalMLData::kMapInt64ToDouble: - gvalue = PbMapToOrtValue(data.map_int64_to_double().v()); - break; - case proto::TraditionalMLData::kTensor: { - gvalue = TensorToOrtValue(data.tensor(), b); - } break; - default: - ORT_NOT_IMPLEMENTED("unknown data type inside TraditionalMLData"); - } - if (!data.debug_info().empty()) { - oss << ":" << data.debug_info(); - } - std::string value_name = data.name(); - if (value_name.empty()) { - const size_t c = name_data_map.size(); - value_name = is_input ? modelinfo.GetInputName(c) : modelinfo.GetOutputName(c); - } - - auto p = name_data_map.emplace(value_name, std::move(gvalue)); - if (!p.second) { - ORT_THROW("duplicated test data name"); - break; - } - } - ORT_CATCH(onnxruntime::NotImplementedException & ex) { - ORT_HANDLE_EXCEPTION([&]() { - std::ostringstream oss2; - oss2 << "load the " << item_id << "-th item failed," << ex.what(); - ORT_NOT_IMPLEMENTED(oss2.str()); - }); - } - ORT_CATCH(const std::exception& ex) { - ORT_HANDLE_EXCEPTION([&]() { - std::ostringstream oss2; - oss2 << "load the " << item_id << "-th item failed," << ex.what(); - ORT_THROW(oss2.str()); - }); - } - } - if (!clean_eof) { - ORT_THROW("parse input file failed, has extra unparsed data"); - } -} - } // namespace - #if !defined(ORT_MINIMAL_BUILD) std::unique_ptr TestModelInfo::LoadOnnxModel(const std::filesystem::path& model_url) { - return std::make_unique(model_url); + return std::make_unique(model_url); } #endif std::unique_ptr TestModelInfo::LoadOrtModel(const std::filesystem::path& model_url) { - return std::make_unique(model_url, true); + return std::make_unique(model_url, true); } /** @@ -481,53 +295,7 @@ void OnnxTestCase::LoadTestData(size_t id, onnxruntime::test::HeapBuffer& b, ORT_THROW("index out of bound"); } - std::filesystem::path test_data_pb = - test_data_dirs_[id] / (is_input ? ORT_TSTR("inputs.pb") : ORT_TSTR("outputs.pb")); - int test_data_pb_fd; - auto st = Env::Default().FileOpenRd(test_data_pb.string(), test_data_pb_fd); - if (st.IsOK()) { // has an all-in-one input file - std::ostringstream oss; - { - std::lock_guard l(m_); - oss << debuginfo_strings_[id]; - } - ORT_TRY { - LoopDataFile(test_data_pb_fd, is_input, *model_info_, name_data_map, b, oss); - } - ORT_CATCH(const std::exception& ex) { - ORT_HANDLE_EXCEPTION([&]() { - std::ostringstream oss2; - oss2 << "parse data file \"" << ToUTF8String(test_data_pb) << "\" failed:" << ex.what(); - ORT_THROW(oss.str()); - }); - } - - { - std::lock_guard l(m_); - debuginfo_strings_[id] = oss.str(); - } - return; - } - - std::vector test_data_pb_files; - - std::filesystem::path dir_fs_path = test_data_dirs_[id]; - if (!std::filesystem::exists(dir_fs_path)) return; - - for (auto const& dir_entry : std::filesystem::directory_iterator(dir_fs_path)) { - if (!dir_entry.is_regular_file()) continue; - const std::filesystem::path& path = dir_entry.path(); - if (!path.filename().has_extension()) { - continue; - } - if (path.filename().extension().compare(ORT_TSTR(".pb")) != 0) continue; - const std::basic_string file_prefix = - is_input ? ORT_TSTR("input_") : ORT_TSTR("output_"); - auto filename_str = path.filename().native(); - if (filename_str.compare(0, file_prefix.length(), file_prefix) == 0) { - test_data_pb_files.push_back(path.native()); - } - } + std::vector test_data_pb_files = SimpleGlob(test_data_dirs_[id], is_input ? ORT_TSTR("input_*.pb") : ORT_TSTR("output_*.pb")); SortFileNames(test_data_pb_files); @@ -734,121 +502,105 @@ OnnxTestCase::OnnxTestCase(const std::string& test_case_name, _In_ std::unique_p } } -void LoadTests(const std::vector>& input_paths, - const std::vector>& whitelisted_test_cases, - const TestTolerances& tolerances, - const std::unordered_set>& disabled_tests, - std::unique_ptr> broken_tests, - std::unique_ptr> broken_tests_keyword_set, - const std::function)>& process_function) { - std::vector> paths(input_paths); - while (!paths.empty()) { - std::filesystem::path node_data_root_path = paths.back(); - paths.pop_back(); - if (!std::filesystem::exists(node_data_root_path)) continue; - std::filesystem::path my_dir_name = node_data_root_path.filename(); - for (auto const& dir_entry : std::filesystem::directory_iterator(node_data_root_path)) { - if (dir_entry.is_directory()) { - paths.push_back(dir_entry.path()); - continue; - } - if (!dir_entry.is_regular_file()) continue; - std::filesystem::path filename_str = dir_entry.path().filename(); - if (filename_str.empty() || filename_str.native()[0] == ORT_TSTR('.')) { - // Ignore hidden files. - continue; - } - bool is_onnx_format = filename_str.has_extension() && (filename_str.extension().compare(ORT_TSTR(".onnx")) == 0); - bool is_ort_format = filename_str.has_extension() && (filename_str.extension().compare(ORT_TSTR(".ort")) == 0); - bool is_valid_model = false; - -#if !defined(ORT_MINIMAL_BUILD) - is_valid_model = is_onnx_format; -#endif - - is_valid_model = is_valid_model || is_ort_format; - if (!is_valid_model) - continue; - - std::basic_string test_case_name = my_dir_name.native(); - if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0) test_case_name = test_case_name.substr(5); +bool IsValidTest(std::basic_string test_case_name, const std::vector>& whitelisted_test_cases, const std::unordered_set>& disabled_tests) { + if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0) test_case_name = test_case_name.substr(5); - if (!whitelisted_test_cases.empty() && std::find(whitelisted_test_cases.begin(), whitelisted_test_cases.end(), - test_case_name) == whitelisted_test_cases.end()) { - continue; - } - if (disabled_tests.find(test_case_name) != disabled_tests.end()) continue; + if (!whitelisted_test_cases.empty() && std::find(whitelisted_test_cases.begin(), whitelisted_test_cases.end(), + test_case_name) == whitelisted_test_cases.end()) { + return false; + } + return disabled_tests.find(test_case_name) == disabled_tests.end(); +} - std::unique_ptr model_info; +void LoadSingleModel(std::unique_ptr model_info, const TestTolerances& tolerances, std::unique_ptr>& broken_tests, + std::unique_ptr>& broken_tests_keyword_set, + const std::function)>& process_function) { + auto test_case_dir = model_info->GetDir(); + auto test_case_name = test_case_dir.filename().native(); + if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0) test_case_name = test_case_name.substr(5); + auto test_case_name_in_log = test_case_name + ORT_TSTR(" in ") + test_case_dir.native(); - if (is_onnx_format) { -#if !defined(ORT_MINIMAL_BUILD) - model_info = TestModelInfo::LoadOnnxModel(dir_entry.path()); -#else - ORT_THROW("onnx model is not supported in this build"); +#if !defined(ORT_MINIMAL_BUILD) && !defined(USE_QNN) && !defined(USE_VSINPU) + // to skip some models like *-int8 or *-qdq + if ((reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || + (reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { + fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it has training domain"); + return; + } #endif - } else if (is_ort_format) { - model_info = TestModelInfo::LoadOrtModel(dir_entry.path()); - } else { - ORT_NOT_IMPLEMENTED(ToUTF8String(filename_str), " is not supported"); - } - auto test_case_dir = model_info->GetDir(); - auto test_case_name_in_log = test_case_name + ORT_TSTR(" in ") + test_case_dir.native(); + if (broken_tests) { + BrokenTest t = {ToUTF8String(test_case_name), ""}; + auto iter = broken_tests->find(t); + auto opset_version = model_info->GetNominalOpsetVersion(); + if (iter != broken_tests->end() && + (opset_version == TestModelInfo::unknown_version || iter->broken_opset_versions_.empty() || + iter->broken_opset_versions_.find(opset_version) != iter->broken_opset_versions_.end())) { + fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to broken_tests"); + return; + } + } -#if !defined(ORT_MINIMAL_BUILD) && !defined(USE_QNN) && !defined(USE_VSINPU) - // to skip some models like *-int8 or *-qdq - if ((reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || - (reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { - fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it has training domain"); - continue; + if (broken_tests_keyword_set) { + for (auto iter2 = broken_tests_keyword_set->begin(); iter2 != broken_tests_keyword_set->end(); ++iter2) { + std::string keyword = *iter2; + if (ToUTF8String(test_case_name).find(keyword) != std::string::npos) { + fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it is in broken test keywords"); + return; } -#endif + } + } - bool has_test_data = false; - LoopDir(test_case_dir, [&](const PATH_CHAR_TYPE* filename, OrtFileType f_type) -> bool { - if (filename[0] == '.') return true; - if (f_type == OrtFileType::TYPE_DIR) { - has_test_data = true; - return false; - } - return true; - }); - if (!has_test_data) { - fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to no test data"); - continue; - } + const auto tolerance_key = ToUTF8String(test_case_dir.filename()); + + std::unique_ptr l = CreateOnnxTestCase(ToUTF8String(test_case_name), std::move(model_info), + tolerances.absolute(tolerance_key), + tolerances.relative(tolerance_key)); + fprintf(stdout, "Load Test Case: %s\n", ToUTF8String(test_case_name_in_log).c_str()); + process_function(std::move(l)); +} - if (broken_tests) { - BrokenTest t = {ToUTF8String(test_case_name), ""}; - auto iter = broken_tests->find(t); - auto opset_version = model_info->GetNominalOpsetVersion(); - if (iter != broken_tests->end() && - (opset_version == TestModelInfo::unknown_version || iter->broken_opset_versions_.empty() || - iter->broken_opset_versions_.find(opset_version) != iter->broken_opset_versions_.end())) { - fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to broken_tests"); +void LoadTests(const std::vector>& input_paths, + const std::vector>& whitelisted_test_cases, + const TestTolerances& tolerances, + const std::unordered_set>& disabled_tests, + std::unique_ptr> broken_tests, + std::unique_ptr> broken_tests_keyword_set, + const std::function)>& process_function) { + std::vector onnx_models; + std::vector ort_models; + for (const std::basic_string& path_str : input_paths) { + ORT_TRY { + for (auto& dir_entry : std::filesystem::recursive_directory_iterator(path_str)) { + if (!dir_entry.is_regular_file() || dir_entry.is_directory()) continue; + std::filesystem::path node_data_root_path = dir_entry.path(); + std::filesystem::path filename_str = dir_entry.path().filename(); + if (filename_str.empty() || filename_str.native()[0] == ORT_TSTR('.')) { + // Ignore hidden files. continue; } - } - - if (broken_tests_keyword_set) { - for (auto iter2 = broken_tests_keyword_set->begin(); iter2 != broken_tests_keyword_set->end(); ++iter2) { - std::string keyword = *iter2; - if (ToUTF8String(test_case_name).find(keyword) != std::string::npos) { - fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it is in broken test keywords"); - continue; - } + auto folder_path = node_data_root_path.parent_path().native(); + if (FnmatchSimple(ORT_TSTR("*.onnx"), filename_str.native()) && IsValidTest(folder_path, whitelisted_test_cases, disabled_tests)) { + onnx_models.push_back(node_data_root_path); + } else if (FnmatchSimple(ORT_TSTR("*.ort"), filename_str.native()) && IsValidTest(folder_path, whitelisted_test_cases, disabled_tests)) { + ort_models.push_back(node_data_root_path); } } - - const auto tolerance_key = ToUTF8String(my_dir_name); - - std::unique_ptr l = CreateOnnxTestCase(ToUTF8String(test_case_name), std::move(model_info), - tolerances.absolute(tolerance_key), - tolerances.relative(tolerance_key)); - fprintf(stdout, "Load Test Case: %s\n", ToUTF8String(test_case_name_in_log).c_str()); - process_function(std::move(l)); } + ORT_CATCH(const std::filesystem::filesystem_error&) { + // silently ignore the directories that do not exist + } + } + +#if !defined(ORT_MINIMAL_BUILD) + // The for-loop below needs to load every ONNX model into memory then destory the in-memory objects, which is very inefficient since 1. in total we need to load every model twice 2. at here we do the job sequentially. + // Originally the design was to make the TestModelInfo lightweight so that all the model information can be retrieved from filesystem meta data without actually loading the models. + for (const std::filesystem::path& model_path : onnx_models) { + LoadSingleModel(TestModelInfo::LoadOnnxModel(model_path), tolerances, broken_tests, broken_tests_keyword_set, process_function); + } +#endif + for (const std::filesystem::path& model_path : ort_models) { + LoadSingleModel(TestModelInfo::LoadOrtModel(model_path), tolerances, broken_tests, broken_tests_keyword_set, process_function); } } diff --git a/onnxruntime/test/onnx/TestCase.h b/onnxruntime/test/onnx/TestCase.h index 745a1fe9eeb50..66fe524c5bea8 100644 --- a/onnxruntime/test/onnx/TestCase.h +++ b/onnxruntime/test/onnx/TestCase.h @@ -10,15 +10,12 @@ #include #include #include +#include "onnx_model_info.h" namespace Ort { struct Value; } -namespace ONNX_NAMESPACE { -class ValueInfoProto; -} - namespace onnxruntime { namespace test { class HeapBuffer; @@ -49,32 +46,6 @@ class ITestCase { virtual void GetPostProcessing(bool* value) const = 0; }; -class TestModelInfo { - public: - virtual const std::filesystem::path& GetModelUrl() const = 0; - virtual std::filesystem::path GetDir() const { - const auto& p = GetModelUrl(); - return p.has_parent_path() ? p.parent_path() : std::filesystem::current_path(); - } - virtual const std::string& GetNodeName() const = 0; - virtual const ONNX_NAMESPACE::ValueInfoProto* GetInputInfoFromModel(size_t i) const = 0; - virtual const ONNX_NAMESPACE::ValueInfoProto* GetOutputInfoFromModel(size_t i) const = 0; - virtual int GetInputCount() const = 0; - virtual int GetOutputCount() const = 0; - virtual const std::string& GetInputName(size_t i) const = 0; - virtual const std::string& GetOutputName(size_t i) const = 0; - virtual std::string GetNominalOpsetVersion() const { return ""; } - virtual ~TestModelInfo() = default; - -#if !defined(ORT_MINIMAL_BUILD) - static std::unique_ptr LoadOnnxModel(const std::filesystem::path& model_url); -#endif - - static std::unique_ptr LoadOrtModel(const std::filesystem::path& model_url); - - static const std::string unknown_version; -}; - std::unique_ptr CreateOnnxTestCase(const std::string& test_case_name, std::unique_ptr model, double default_per_sample_tolerance, diff --git a/onnxruntime/test/onnx/onnx_model_info.cc b/onnxruntime/test/onnx/onnx_model_info.cc index f23012aee9fd2..1e5684ced512b 100644 --- a/onnxruntime/test/onnx/onnx_model_info.cc +++ b/onnxruntime/test/onnx/onnx_model_info.cc @@ -14,7 +14,7 @@ using namespace onnxruntime; -OnnxModelInfo::OnnxModelInfo(const std::filesystem::path& model_url, bool is_ort_model) +TestModelInfo::TestModelInfo(const std::filesystem::path& model_url, bool is_ort_model) : model_url_(model_url) { if (is_ort_model) { InitOrtModelInfo(model_url); @@ -38,7 +38,7 @@ static void RepeatedPtrFieldToVector(const ::google::protobuf::RepeatedPtrField< } } -void OnnxModelInfo::InitOnnxModelInfo(const std::filesystem::path& model_url) { // parse model +void TestModelInfo::InitOnnxModelInfo(const std::filesystem::path& model_url) { // parse model int model_fd; auto st = Env::Default().FileOpenRd(model_url, model_fd); if (!st.IsOK()) { @@ -93,7 +93,7 @@ void OnnxModelInfo::InitOnnxModelInfo(const std::filesystem::path& model_url) { #endif // #if !defined(ORT_MINIMAL_BUILD) -void OnnxModelInfo::InitOrtModelInfo(const std::filesystem::path& model_url) { +void TestModelInfo::InitOrtModelInfo(const std::filesystem::path& model_url) { std::vector bytes; size_t num_bytes = 0; const auto model_location = ToWideString(model_url); diff --git a/onnxruntime/test/onnx/onnx_model_info.h b/onnxruntime/test/onnx/onnx_model_info.h index 48e297376aff5..66fc192aba063 100644 --- a/onnxruntime/test/onnx/onnx_model_info.h +++ b/onnxruntime/test/onnx/onnx_model_info.h @@ -1,11 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #pragma once +#include +#include +#include +#include -#include "core/graph/onnx_protobuf.h" -#include "TestCase.h" - -class OnnxModelInfo : public TestModelInfo { +class TestModelInfo { private: std::string node_name_; // Due to performance, the opset version is get from directory name, so it's nominal @@ -22,7 +23,15 @@ class OnnxModelInfo : public TestModelInfo { void InitOrtModelInfo(const std::filesystem::path& model_url); public: - OnnxModelInfo(const std::filesystem::path& path, bool is_ort_model = false); +#if !defined(ORT_MINIMAL_BUILD) + static std::unique_ptr LoadOnnxModel(const std::filesystem::path& model_url); +#endif + static std::unique_ptr LoadOrtModel(const std::filesystem::path& model_url); + TestModelInfo(const std::filesystem::path& path, bool is_ort_model = false); + std::filesystem::path GetDir() const { + const auto& p = GetModelUrl(); + return p.has_parent_path() ? p.parent_path() : std::filesystem::current_path(); + } bool HasDomain(const std::string& name) const { return domain_to_version_.find(name) != domain_to_version_.end(); } @@ -32,21 +41,23 @@ class OnnxModelInfo : public TestModelInfo { return iter == domain_to_version_.end() ? -1 : iter->second; } - const std::filesystem::path& GetModelUrl() const override { return model_url_; } - std::string GetNominalOpsetVersion() const override { return onnx_nominal_opset_vesion_; } + const std::filesystem::path& GetModelUrl() const { return model_url_; } + std::string GetNominalOpsetVersion() const { return onnx_nominal_opset_vesion_; } - const std::string& GetNodeName() const override { return node_name_; } + const std::string& GetNodeName() const { return node_name_; } - const ONNX_NAMESPACE::ValueInfoProto* GetInputInfoFromModel(size_t i) const override { + const ONNX_NAMESPACE::ValueInfoProto* GetInputInfoFromModel(size_t i) const { return &input_value_info_[i]; } - const ONNX_NAMESPACE::ValueInfoProto* GetOutputInfoFromModel(size_t i) const override { + const ONNX_NAMESPACE::ValueInfoProto* GetOutputInfoFromModel(size_t i) const { return &output_value_info_[i]; } - int GetInputCount() const override { return static_cast(input_value_info_.size()); } - int GetOutputCount() const override { return static_cast(output_value_info_.size()); } - const std::string& GetInputName(size_t i) const override { return input_value_info_[i].name(); } - const std::string& GetOutputName(size_t i) const override { return output_value_info_[i].name(); } + int GetInputCount() const { return static_cast(input_value_info_.size()); } + int GetOutputCount() const { return static_cast(output_value_info_.size()); } + const std::string& GetInputName(size_t i) const { return input_value_info_[i].name(); } + const std::string& GetOutputName(size_t i) const { return output_value_info_[i].name(); } + + static const std::string unknown_version; }; diff --git a/onnxruntime/test/onnx/pb_helper.h b/onnxruntime/test/onnx/pb_helper.h index cd73c53c1979d..01c017d28d082 100644 --- a/onnxruntime/test/onnx/pb_helper.h +++ b/onnxruntime/test/onnx/pb_helper.h @@ -42,7 +42,6 @@ #include #include #include -#include "tml.pb.h" #ifdef __GNUC__ #pragma GCC diagnostic pop #endif diff --git a/onnxruntime/test/proto/tml.proto b/onnxruntime/test/proto/tml.proto deleted file mode 100644 index 2071e2a55dcff..0000000000000 --- a/onnxruntime/test/proto/tml.proto +++ /dev/null @@ -1,108 +0,0 @@ -syntax = "proto2"; - -import "onnx/onnx-ml.proto"; - -package onnxruntime.proto; - -//must sync with data_types.h -//MapStringToString -//MapStringToInt64 -//MapStringToFloat -//MapStringToDouble -//MapInt64ToString -//MapInt64ToInt64 -//MapInt64ToFloat -//MapInt64ToDouble -//VectorString -//VectorFloat -//VectorInt64 -//VectorDouble -//VectorMapStringToFloat -//VectorMapInt64ToFloat - -message MapStringToString { - map v = 1; -} - -message MapStringToInt64 { - map v = 1; -} - -message MapStringToDouble { - map v = 1; -} - -message MapStringToFloat { - map v = 1; -} - -message MapInt64ToString { - map v = 1; -} - -message MapInt64ToInt64 { - map v = 1; -} - -message MapInt64ToFloat { - map v = 1; -} - -message MapInt64ToDouble { - map v = 1; -} - -message VectorString { - repeated string v = 1; -} - -message VectorFloat { - repeated float v = 1; -} - -message VectorInt64 { - repeated int64 v = 1; -} - -message VectorDouble { - repeated double v = 1; -} - -message VectorMapStringToFloat { - repeated MapStringToFloat v = 1; -} - -message VectorMapInt64ToFloat { - repeated MapInt64ToFloat v = 1; -} - -message TraditionalMLData { - oneof values { - MapStringToString map_string_to_string = 1; - MapStringToInt64 map_string_to_int64 = 2; - MapStringToFloat map_string_to_float = 3; - MapStringToDouble map_string_to_double = 4; - MapInt64ToString map_int64_to_string = 5; - MapInt64ToInt64 map_int64_to_int64 = 6; - MapInt64ToFloat map_int64_to_float = 7; - MapInt64ToDouble map_int64_to_double = 8; - VectorString vector_string = 9; - VectorFloat vector_float = 10; - VectorInt64 vector_int64 = 11; - VectorDouble vector_double = 12; - VectorMapStringToFloat vector_map_string_to_float = 13; - VectorMapInt64ToFloat vector_map_int64_to_float = 14; - onnx.TensorProto tensor = 16; - } - // Optionally, a name for the tensor. - optional string name = 15; - optional string debug_info = 17; //A human-readable string for helping debugging -} - -message TestCaseConfig { - optional double per_sample_tolerance = 1; - optional double relative_per_sample_tolerance = 2; - optional bool post_processing = 3; -} -// For using protobuf-lite -option optimize_for = LITE_RUNTIME; diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index e3c86a137484f..ca36798cfa79d 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -21,7 +21,6 @@ #include "asserts.h" #include #include "default_providers.h" -#include "test/onnx/TestCase.h" #ifdef USE_DNNL #include "core/providers/dnnl/dnnl_provider_factory.h" @@ -52,6 +51,7 @@ // test infrastructure #include "test/onnx/testenv.h" #include "test/onnx/TestCase.h" +#include "test/onnx/fnmatch_simple.h" #include "test/compare_ortvalue.h" #include "test/onnx/heap_buffer.h" #include "test/onnx/onnx_model_info.h" @@ -102,7 +102,7 @@ TEST_P(ModelTest, Run) { } } - std::unique_ptr model_info = std::make_unique(model_path.c_str()); + std::unique_ptr model_info = std::make_unique(model_path.c_str()); if (model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { @@ -746,45 +746,42 @@ ::std::vector<::std::basic_string> GetParameterStrings() { all_disabled_tests.insert(ORT_TSTR("fp16_inception_v1")); all_disabled_tests.insert(ORT_TSTR("fp16_tiny_yolov2")); - while (!paths.empty()) { - std::filesystem::path node_data_root_path = paths.back(); - paths.pop_back(); - if (!std::filesystem::exists(node_data_root_path) || !std::filesystem::is_directory(node_data_root_path)) { - continue; - } - for (auto const& dir_entry : std::filesystem::directory_iterator(node_data_root_path)) { - if (dir_entry.is_directory()) { - paths.push_back(dir_entry.path()); - continue; - } - const std::filesystem::path& path = dir_entry.path(); - if (!path.has_filename() || path.filename().native().compare(0, 1, ORT_TSTR(".")) == 0) { - // Ignore hidden files. - continue; - } - if (path.filename().extension().compare(ORT_TSTR(".onnx")) != 0) { - // Ignore the files that are not ONNX models - continue; - } - std::basic_string test_case_name = path.parent_path().filename().native(); - if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0) - test_case_name = test_case_name.substr(5); - if (all_disabled_tests.find(test_case_name) != all_disabled_tests.end()) - continue; + for (const std::filesystem::path& root_dir : paths) { + ORT_TRY { + for (auto& dir_entry : std::filesystem::recursive_directory_iterator(root_dir)) { + if (!dir_entry.is_regular_file() || dir_entry.is_directory()) continue; + std::filesystem::path node_data_root_path = dir_entry.path(); + std::filesystem::path filename_str = dir_entry.path().filename(); + if (filename_str.empty() || filename_str.native()[0] == ORT_TSTR('.')) { + // Ignore hidden files. + continue; + } + auto folder_path = node_data_root_path.parent_path().native(); + if (FnmatchSimple(ORT_TSTR("*.onnx"), filename_str.native())) { + std::basic_string test_case_name = node_data_root_path.parent_path().filename().native(); + if (test_case_name.compare(0, 5, ORT_TSTR("test_")) == 0) + test_case_name = test_case_name.substr(5); + if (all_disabled_tests.find(test_case_name) != all_disabled_tests.end()) + continue; #ifdef DISABLE_ML_OPS - auto starts_with = [](const std::basic_string& find_in, - const std::basic_string& find_what) { - return find_in.compare(0, find_what.size(), find_what) == 0; - }; - if (starts_with(test_case_name, ORT_TSTR("XGBoost_")) || starts_with(test_case_name, ORT_TSTR("coreml_")) || - starts_with(test_case_name, ORT_TSTR("scikit_")) || starts_with(test_case_name, ORT_TSTR("libsvm_"))) { - continue; - } + auto starts_with = [](const std::basic_string& find_in, + const std::basic_string& find_what) { + return find_in.compare(0, find_what.size(), find_what) == 0; + }; + if (starts_with(test_case_name, ORT_TSTR("XGBoost_")) || starts_with(test_case_name, ORT_TSTR("coreml_")) || + starts_with(test_case_name, ORT_TSTR("scikit_")) || starts_with(test_case_name, ORT_TSTR("libsvm_"))) { + continue; + } #endif - std::basic_ostringstream oss; - oss << provider_name << ORT_TSTR("_") << path.native(); - v.emplace_back(oss.str()); + std::basic_ostringstream oss; + oss << provider_name << ORT_TSTR("_") << node_data_root_path.native(); + v.emplace_back(oss.str()); + } + } + } + ORT_CATCH(const std::filesystem::filesystem_error&) { + // silently ignore the directories that do not exist } } }