From b61527000f3a93f4ffba4e313e1617cfb10f761c Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 22 Aug 2024 13:28:25 +0000 Subject: [PATCH] =?UTF-8?q?fixup!=20Ensure=20we=20support=20all=20inputs?= =?UTF-8?q?=20for=20MatMulInteger=20and=20ConvInteger.=20Limit=E2=80=A6=20?= =?UTF-8?q?(#53)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../migraphx/migraphx_execution_provider.cc | 243 +++++++++--------- .../python/onnxruntime_pybind_state.cc | 88 +++---- 2 files changed, 144 insertions(+), 187 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index ed1bec59ca736..9ef24c176e468 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -381,159 +381,146 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) { return true; } - } -} -else if (optype == "NonZero") { - if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes)) { - return true; - } -} -else if (optype == "OneHot") { - if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { - return true; - } -} -else if (optype == "Pad") { - const auto& args = node->InputDefs(); - // if pad size is not constant, migraphx cannot support - if (args.size() >= 2) { + } else if (optype == "NonZero") { + if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes)) { + return true; + } + } else if (optype == "OneHot") { if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return true; } - } - - const auto& attributes = node->GetAttributes(); - // Pad only support constant mode - auto mode_attr = attributes.find("mode"); - std::string mode = "constant"; - if (mode_attr != attributes.end()) { - mode = (*mode_attr).second.s(); - } - static const std::set allowed_modes = {"constant", "reflect"}; - if (allowed_modes.count(mode) == 0) { - return true; - } - - // input value only applied to constant mode - if (mode == "constant") { - if (args.size() == 3) { - if (!canEvalNodeArgument(graph_viewer, node, {2}, input_nodes)) { + } else if (optype == "Pad") { + const auto& args = node->InputDefs(); + // if pad size is not constant, migraphx cannot support + if (args.size() >= 2) { + if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { return true; } } - } -} -else if (optype == "Range") { - auto arg_num = node->InputDefs().size(); - std::vector vec(arg_num); - std::iota(vec.begin(), vec.end(), 0); - if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) { - return true; - } -} -else if (optype == "Reshape") { - const auto& args = node->InputDefs(); - if (args.size() == 2) { - if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { - return false; + + const auto& attributes = node->GetAttributes(); + // Pad only support constant mode + auto mode_attr = attributes.find("mode"); + std::string mode = "constant"; + if (mode_attr != attributes.end()) { + mode = (*mode_attr).second.s(); } - return true; - } -} -else if (optype == "Resize" or optype == "Upsample") { - const auto& attributes = node->GetAttributes(); - auto ct_attr = attributes.find("coordinate_transformation_mode"); - if (ct_attr != attributes.end()) { - auto ct = (*ct_attr).second.s(); - if (ct == "tf_crop_and_resize") { + static const std::set allowed_modes = {"constant", "reflect"}; + if (allowed_modes.count(mode) == 0) { return true; } - } - auto mode_attr = attributes.find("mode"); - if (mode_attr != attributes.end()) { - auto mode = (*mode_attr).second.s(); - if (mode == "cubic") { + // input value only applied to constant mode + if (mode == "constant") { + if (args.size() == 3) { + if (!canEvalNodeArgument(graph_viewer, node, {2}, input_nodes)) { + return true; + } + } + } + } else if (optype == "Range") { + auto arg_num = node->InputDefs().size(); + std::vector vec(arg_num); + std::iota(vec.begin(), vec.end(), 0); + if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) { return true; } - } -} -else if (optype == "ReduceSum") { - const auto& args = node->InputDefs(); - if (args.size() == 2) { - if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { - return false; + } else if (optype == "Reshape") { + const auto& args = node->InputDefs(); + if (args.size() == 2) { + if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { + return false; + } + return true; + } + } else if (optype == "Resize" or optype == "Upsample") { + const auto& attributes = node->GetAttributes(); + auto ct_attr = attributes.find("coordinate_transformation_mode"); + if (ct_attr != attributes.end()) { + auto ct = (*ct_attr).second.s(); + if (ct == "tf_crop_and_resize") { + return true; + } } - return true; - } -} -else if (optype == "Slice") { - // MIGraphX does not properly handle the situation where any - // value of the "starts" attribute is higher than a corresponding - // value in the "ends" - auto arg_num = node->InputDefs().size(); - std::vector vec(arg_num); - std::iota(vec.begin(), vec.end(), 0); - vec.erase(vec.begin()); - if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) { - return true; - } - const auto& attributes = node->GetAttributes(); - if (attributes.count("starts") > 0 and attributes.count("ends") > 0) { - auto starts = toVector((*attributes.find("starts")).second.ints()); - auto ends = toVector((*attributes.find("ends")).second.ints()); - for (std::size_t i = 0; i < starts.size(); ++i) { - if (starts.at(i) > ends.at(i)) { + auto mode_attr = attributes.find("mode"); + if (mode_attr != attributes.end()) { + auto mode = (*mode_attr).second.s(); + if (mode == "cubic") { return true; } } - } -} -else if (optype == "Split") { - // cannot process input dim of 0 size - const auto arg_s = node->InputDefs()[0]->Shape(); - if (arg_s != nullptr) { - const auto& tensor_dims = arg_s->dim(); - std::vector dims; - for (auto&& dim : tensor_dims) { - dims.emplace_back(dim.has_dim_value() ? dim.dim_value() : 0); + } else if (optype == "ReduceSum") { + const auto& args = node->InputDefs(); + if (args.size() == 2) { + if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { + return false; + } + return true; } - if (dims == std::vector{0}) { + } else if (optype == "Slice") { + // MIGraphX does not properly handle the situation where any + // value of the "starts" attribute is higher than a corresponding + // value in the "ends" + auto arg_num = node->InputDefs().size(); + std::vector vec(arg_num); + std::iota(vec.begin(), vec.end(), 0); + vec.erase(vec.begin()); + if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) { return true; } - } - const auto& args = node->InputDefs(); - if (args.size() == 2) { - if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { - return false; + const auto& attributes = node->GetAttributes(); + if (attributes.count("starts") > 0 and attributes.count("ends") > 0) { + auto starts = toVector((*attributes.find("starts")).second.ints()); + auto ends = toVector((*attributes.find("ends")).second.ints()); + for (std::size_t i = 0; i < starts.size(); ++i) { + if (starts.at(i) > ends.at(i)) { + return true; + } + } } - return true; - } -} -else if (optype == "Tile") { - if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { - return true; - } -} -else if (optype == "TopK") { - if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { - return true; - } -} -else if (optype == "Unsqueeze" or optype == "Squeeze") { - const auto& args = node->InputDefs(); - if (args.size() == 2) { - if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { - return false; + } else if (optype == "Split") { + // cannot process input dim of 0 size + const auto arg_s = node->InputDefs()[0]->Shape(); + if (arg_s != nullptr) { + const auto& tensor_dims = arg_s->dim(); + std::vector dims; + for (auto&& dim : tensor_dims) { + dims.emplace_back(dim.has_dim_value() ? dim.dim_value() : 0); + } + if (dims == std::vector{0}) { + return true; + } + } + + const auto& args = node->InputDefs(); + if (args.size() == 2) { + if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { + return false; + } + return true; + } + } else if (optype == "Tile") { + if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { + return true; + } + } else if (optype == "TopK") { + if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { + return true; + } + } else if (optype == "Unsqueeze" or optype == "Squeeze") { + const auto& args = node->InputDefs(); + if (args.size() == 2) { + if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) { + return false; + } + return true; } - return true; } -} -// Op doesn't fall into known any of unsupported modes. -return false; + // Op doesn't fall into known any of unsupported modes. + return false; } void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector>& clusters, diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 4daf6c640e464..8bd7bf80ebd85 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1346,7 +1346,7 @@ void addGlobalMethods(py::module& m) { ORT_UNUSED_PARAMETER(algo); ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM"); #else - cudnn_conv_algo_search = algo; + cudnn_conv_algo_search = algo; #endif }); // TODO remove deprecated global config @@ -1357,7 +1357,7 @@ void addGlobalMethods(py::module& m) { ORT_UNUSED_PARAMETER(use_single_stream); ORT_THROW("set_do_copy_in_default_stream is not supported in ROCM"); #else - do_copy_in_default_stream = use_single_stream; + do_copy_in_default_stream = use_single_stream; #endif }); // TODO remove deprecated global config @@ -1721,10 +1721,10 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc") } ORT_THROW_IF_ERROR(options->value.AddExternalInitializers(names_ptrs, values_ptrs)); #else - ORT_UNUSED_PARAMETER(options); - ORT_UNUSED_PARAMETER(names); - ORT_UNUSED_PARAMETER(ort_values); - ORT_THROW("External initializers are not supported in this build."); + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(names); + ORT_UNUSED_PARAMETER(ort_values); + ORT_THROW("External initializers are not supported in this build."); #endif }); @@ -1786,8 +1786,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") return *(na.Type()); }, "node type") - .def( - "__str__", [](const onnxruntime::NodeArg& na) -> std::string { + .def("__str__", [](const onnxruntime::NodeArg& na) -> std::string { std::ostringstream res; res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape="; auto shape = na.Shape(); @@ -1813,11 +1812,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") } res << ")"; - return std::string(res.str()); - }, - "converts the node into a readable string") - .def_property_readonly( - "shape", [](const onnxruntime::NodeArg& na) -> std::vector { + return std::string(res.str()); }, "converts the node into a readable string") + .def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector { auto shape = na.Shape(); std::vector arr; if (shape == nullptr || shape->dim_size() == 0) { @@ -1834,9 +1830,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") arr[i] = py::none(); } } - return arr; - }, - "node shape (assuming the node holds a tensor)"); + return arr; }, "node shape (assuming the node holds a tensor)"); py::class_ sessionObjectInitializer(m, "SessionObjectInitializer"); py::class_(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc") @@ -2024,51 +2018,28 @@ including arg name, arg type (contains both type and shape).)pbdoc") .def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t { return sess->GetSessionHandle()->GetProfiling().GetStartTimeNs(); }) - .def( - "get_providers", [](const PyInferenceSession* sess) -> const std::vector& { - return sess->GetSessionHandle()->GetRegisteredProviderTypes(); - }, - py::return_value_policy::reference_internal) - .def( - "get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& { - return sess->GetSessionHandle()->GetAllProviderOptions(); - }, - py::return_value_policy::reference_internal) - .def_property_readonly( - "session_options", [](const PyInferenceSession* sess) -> PySessionOptions* { + .def("get_providers", [](const PyInferenceSession* sess) -> const std::vector& { return sess->GetSessionHandle()->GetRegisteredProviderTypes(); }, py::return_value_policy::reference_internal) + .def("get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& { return sess->GetSessionHandle()->GetAllProviderOptions(); }, py::return_value_policy::reference_internal) + .def_property_readonly("session_options", [](const PyInferenceSession* sess) -> PySessionOptions* { auto session_options = std::make_unique(); session_options->value = sess->GetSessionHandle()->GetSessionOptions(); - return session_options.release(); - }, - py::return_value_policy::take_ownership) - .def_property_readonly( - "inputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { + return session_options.release(); }, py::return_value_policy::take_ownership) + .def_property_readonly("inputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { auto res = sess->GetSessionHandle()->GetModelInputs(); OrtPybindThrowIfError(res.first); - return *(res.second); - }, - py::return_value_policy::reference_internal) - .def_property_readonly( - "outputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { + return *(res.second); }, py::return_value_policy::reference_internal) + .def_property_readonly("outputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { auto res = sess->GetSessionHandle()->GetModelOutputs(); OrtPybindThrowIfError(res.first); - return *(res.second); - }, - py::return_value_policy::reference_internal) - .def_property_readonly( - "overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector& { + return *(res.second); }, py::return_value_policy::reference_internal) + .def_property_readonly("overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector& { auto res = sess->GetSessionHandle()->GetOverridableInitializers(); OrtPybindThrowIfError(res.first); - return *(res.second); - }, - py::return_value_policy::reference_internal) - .def_property_readonly( - "model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& { + return *(res.second); }, py::return_value_policy::reference_internal) + .def_property_readonly("model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& { auto res = sess->GetSessionHandle()->GetModelMetadata(); OrtPybindThrowIfError(res.first); - return *(res.second); - }, - py::return_value_policy::reference_internal) + return *(res.second); }, py::return_value_policy::reference_internal) .def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void { Status status; // release GIL to allow multiple python threads to invoke Run() in parallel. @@ -2078,8 +2049,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") else status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get()); if (!status.IsOK()) - throw std::runtime_error("Error in execution: " + status.ErrorMessage()); - }) + throw std::runtime_error("Error in execution: " + status.ErrorMessage()); }) .def("get_tuning_results", [](PyInferenceSession* sess) -> py::list { #if !defined(ORT_MINIMAL_BUILD) py::list ret; @@ -2093,8 +2063,8 @@ including arg name, arg type (contains both type and shape).)pbdoc") return ret; #else - ORT_UNUSED_PARAMETER(sess); - ORT_THROW("TunableOp and get_tuning_results are not supported in this build."); + ORT_UNUSED_PARAMETER(sess); + ORT_THROW("TunableOp and get_tuning_results are not supported in this build."); #endif }) .def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void { @@ -2125,10 +2095,10 @@ including arg name, arg type (contains both type and shape).)pbdoc") throw std::runtime_error("Error in execution: " + status.ErrorMessage()); } #else - ORT_UNUSED_PARAMETER(sess); - ORT_UNUSED_PARAMETER(results); - ORT_UNUSED_PARAMETER(error_on_invalid); - ORT_THROW("TunableOp and set_tuning_results are not supported in this build."); + ORT_UNUSED_PARAMETER(sess); + ORT_UNUSED_PARAMETER(results); + ORT_UNUSED_PARAMETER(error_on_invalid); + ORT_THROW("TunableOp and set_tuning_results are not supported in this build."); #endif });