Skip to content

Commit

Permalink
fixup! Ensure we support all inputs for MatMulInteger and ConvInteger…
Browse files Browse the repository at this point in the history
…. Limit… (#53) (#56)
  • Loading branch information
TedThemistokleous authored Aug 22, 2024
1 parent 539fc8c commit 9e8a798
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 187 deletions.
243 changes: 115 additions & 128 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,159 +381,146 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
(input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) {
return true;
}
}
}
else if (optype == "NonZero") {
if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes)) {
return true;
}
}
else if (optype == "OneHot") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}
else if (optype == "Pad") {
const auto& args = node->InputDefs();
// if pad size is not constant, migraphx cannot support
if (args.size() >= 2) {
} else if (optype == "NonZero") {
if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes)) {
return true;
}
} else if (optype == "OneHot") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}

const auto& attributes = node->GetAttributes();
// Pad only support constant mode
auto mode_attr = attributes.find("mode");
std::string mode = "constant";
if (mode_attr != attributes.end()) {
mode = (*mode_attr).second.s();
}
static const std::set<std::string> allowed_modes = {"constant", "reflect"};
if (allowed_modes.count(mode) == 0) {
return true;
}

// input value only applied to constant mode
if (mode == "constant") {
if (args.size() == 3) {
if (!canEvalNodeArgument(graph_viewer, node, {2}, input_nodes)) {
} else if (optype == "Pad") {
const auto& args = node->InputDefs();
// if pad size is not constant, migraphx cannot support
if (args.size() >= 2) {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}
}
}
else if (optype == "Range") {
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) {
return true;
}
}
else if (optype == "Reshape") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;

const auto& attributes = node->GetAttributes();
// Pad only support constant mode
auto mode_attr = attributes.find("mode");
std::string mode = "constant";
if (mode_attr != attributes.end()) {
mode = (*mode_attr).second.s();
}
return true;
}
}
else if (optype == "Resize" or optype == "Upsample") {
const auto& attributes = node->GetAttributes();
auto ct_attr = attributes.find("coordinate_transformation_mode");
if (ct_attr != attributes.end()) {
auto ct = (*ct_attr).second.s();
if (ct == "tf_crop_and_resize") {
static const std::set<std::string> allowed_modes = {"constant", "reflect"};
if (allowed_modes.count(mode) == 0) {
return true;
}
}

auto mode_attr = attributes.find("mode");
if (mode_attr != attributes.end()) {
auto mode = (*mode_attr).second.s();
if (mode == "cubic") {
// input value only applied to constant mode
if (mode == "constant") {
if (args.size() == 3) {
if (!canEvalNodeArgument(graph_viewer, node, {2}, input_nodes)) {
return true;
}
}
}
} else if (optype == "Range") {
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) {
return true;
}
}
}
else if (optype == "ReduceSum") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
} else if (optype == "Reshape") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
}
} else if (optype == "Resize" or optype == "Upsample") {
const auto& attributes = node->GetAttributes();
auto ct_attr = attributes.find("coordinate_transformation_mode");
if (ct_attr != attributes.end()) {
auto ct = (*ct_attr).second.s();
if (ct == "tf_crop_and_resize") {
return true;
}
}
return true;
}
}
else if (optype == "Slice") {
// MIGraphX does not properly handle the situation where any
// value of the "starts" attribute is higher than a corresponding
// value in the "ends"
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
vec.erase(vec.begin());
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) {
return true;
}

const auto& attributes = node->GetAttributes();
if (attributes.count("starts") > 0 and attributes.count("ends") > 0) {
auto starts = toVector((*attributes.find("starts")).second.ints());
auto ends = toVector((*attributes.find("ends")).second.ints());
for (std::size_t i = 0; i < starts.size(); ++i) {
if (starts.at(i) > ends.at(i)) {
auto mode_attr = attributes.find("mode");
if (mode_attr != attributes.end()) {
auto mode = (*mode_attr).second.s();
if (mode == "cubic") {
return true;
}
}
}
}
else if (optype == "Split") {
// cannot process input dim of 0 size
const auto arg_s = node->InputDefs()[0]->Shape();
if (arg_s != nullptr) {
const auto& tensor_dims = arg_s->dim();
std::vector<std::size_t> dims;
for (auto&& dim : tensor_dims) {
dims.emplace_back(dim.has_dim_value() ? dim.dim_value() : 0);
} else if (optype == "ReduceSum") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
}
if (dims == std::vector<std::size_t>{0}) {
} else if (optype == "Slice") {
// MIGraphX does not properly handle the situation where any
// value of the "starts" attribute is higher than a corresponding
// value in the "ends"
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
vec.erase(vec.begin());
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) {
return true;
}
}

const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
const auto& attributes = node->GetAttributes();
if (attributes.count("starts") > 0 and attributes.count("ends") > 0) {
auto starts = toVector((*attributes.find("starts")).second.ints());
auto ends = toVector((*attributes.find("ends")).second.ints());
for (std::size_t i = 0; i < starts.size(); ++i) {
if (starts.at(i) > ends.at(i)) {
return true;
}
}
}
return true;
}
}
else if (optype == "Tile") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}
else if (optype == "TopK") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}
else if (optype == "Unsqueeze" or optype == "Squeeze") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
} else if (optype == "Split") {
// cannot process input dim of 0 size
const auto arg_s = node->InputDefs()[0]->Shape();
if (arg_s != nullptr) {
const auto& tensor_dims = arg_s->dim();
std::vector<std::size_t> dims;
for (auto&& dim : tensor_dims) {
dims.emplace_back(dim.has_dim_value() ? dim.dim_value() : 0);
}
if (dims == std::vector<std::size_t>{0}) {
return true;
}
}

const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
}
} else if (optype == "Tile") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
} else if (optype == "TopK") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
} else if (optype == "Unsqueeze" or optype == "Squeeze") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
}
return true;
}
}

// Op doesn't fall into known any of unsupported modes.
return false;
// Op doesn't fall into known any of unsupported modes.
return false;
}

void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector<std::vector<NodeIndex>>& clusters,
Expand Down
Loading

0 comments on commit 9e8a798

Please sign in to comment.