Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixup! Ensure we support all inputs for MatMulInteger and ConvInteger… #57

Merged
merged 1 commit into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 115 additions & 128 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,159 +381,146 @@
(input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) {
return true;
}
}
}
else if (optype == "NonZero") {
if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes)) {
return true;
}
}
else if (optype == "OneHot") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}
else if (optype == "Pad") {
const auto& args = node->InputDefs();
// if pad size is not constant, migraphx cannot support
if (args.size() >= 2) {
} else if (optype == "NonZero") {
if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes)) {
return true;
}
} else if (optype == "OneHot") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}

const auto& attributes = node->GetAttributes();
// Pad only support constant mode
auto mode_attr = attributes.find("mode");
std::string mode = "constant";
if (mode_attr != attributes.end()) {
mode = (*mode_attr).second.s();
}
static const std::set<std::string> allowed_modes = {"constant", "reflect"};
if (allowed_modes.count(mode) == 0) {
return true;
}

// input value only applied to constant mode
if (mode == "constant") {
if (args.size() == 3) {
if (!canEvalNodeArgument(graph_viewer, node, {2}, input_nodes)) {
} else if (optype == "Pad") {
const auto& args = node->InputDefs();
// if pad size is not constant, migraphx cannot support
if (args.size() >= 2) {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}
}
}
else if (optype == "Range") {
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) {
return true;
}
}
else if (optype == "Reshape") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;

const auto& attributes = node->GetAttributes();
// Pad only support constant mode
auto mode_attr = attributes.find("mode");
std::string mode = "constant";
if (mode_attr != attributes.end()) {
mode = (*mode_attr).second.s();
}
return true;
}
}
else if (optype == "Resize" or optype == "Upsample") {
const auto& attributes = node->GetAttributes();
auto ct_attr = attributes.find("coordinate_transformation_mode");
if (ct_attr != attributes.end()) {
auto ct = (*ct_attr).second.s();
if (ct == "tf_crop_and_resize") {
static const std::set<std::string> allowed_modes = {"constant", "reflect"};
if (allowed_modes.count(mode) == 0) {
return true;
}
}

auto mode_attr = attributes.find("mode");
if (mode_attr != attributes.end()) {
auto mode = (*mode_attr).second.s();
if (mode == "cubic") {
// input value only applied to constant mode
if (mode == "constant") {
if (args.size() == 3) {
if (!canEvalNodeArgument(graph_viewer, node, {2}, input_nodes)) {
return true;
}
}
}
} else if (optype == "Range") {
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) {
return true;
}
}
}
else if (optype == "ReduceSum") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
} else if (optype == "Reshape") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
}
} else if (optype == "Resize" or optype == "Upsample") {

Check warning on line 436 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L436

Use operator || instead of or [readability/alt_tokens] [2]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:436:  Use operator || instead of or  [readability/alt_tokens] [2]
const auto& attributes = node->GetAttributes();
auto ct_attr = attributes.find("coordinate_transformation_mode");
if (ct_attr != attributes.end()) {
auto ct = (*ct_attr).second.s();
if (ct == "tf_crop_and_resize") {
return true;
}
}
return true;
}
}
else if (optype == "Slice") {
// MIGraphX does not properly handle the situation where any
// value of the "starts" attribute is higher than a corresponding
// value in the "ends"
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
vec.erase(vec.begin());
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) {
return true;
}

const auto& attributes = node->GetAttributes();
if (attributes.count("starts") > 0 and attributes.count("ends") > 0) {
auto starts = toVector((*attributes.find("starts")).second.ints());
auto ends = toVector((*attributes.find("ends")).second.ints());
for (std::size_t i = 0; i < starts.size(); ++i) {
if (starts.at(i) > ends.at(i)) {
auto mode_attr = attributes.find("mode");
if (mode_attr != attributes.end()) {
auto mode = (*mode_attr).second.s();
if (mode == "cubic") {
return true;
}
}
}
}
else if (optype == "Split") {
// cannot process input dim of 0 size
const auto arg_s = node->InputDefs()[0]->Shape();
if (arg_s != nullptr) {
const auto& tensor_dims = arg_s->dim();
std::vector<std::size_t> dims;
for (auto&& dim : tensor_dims) {
dims.emplace_back(dim.has_dim_value() ? dim.dim_value() : 0);
} else if (optype == "ReduceSum") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
}
if (dims == std::vector<std::size_t>{0}) {
} else if (optype == "Slice") {
// MIGraphX does not properly handle the situation where any
// value of the "starts" attribute is higher than a corresponding
// value in the "ends"
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
vec.erase(vec.begin());
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) {
return true;
}
}

const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
const auto& attributes = node->GetAttributes();
if (attributes.count("starts") > 0 and attributes.count("ends") > 0) {

Check warning on line 474 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L474

Use operator && instead of and [readability/alt_tokens] [2]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:474:  Use operator && instead of and  [readability/alt_tokens] [2]
auto starts = toVector((*attributes.find("starts")).second.ints());
auto ends = toVector((*attributes.find("ends")).second.ints());
for (std::size_t i = 0; i < starts.size(); ++i) {
if (starts.at(i) > ends.at(i)) {
return true;
}
}
}
return true;
}
}
else if (optype == "Tile") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}
else if (optype == "TopK") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}
else if (optype == "Unsqueeze" or optype == "Squeeze") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
} else if (optype == "Split") {
// cannot process input dim of 0 size
const auto arg_s = node->InputDefs()[0]->Shape();
if (arg_s != nullptr) {
const auto& tensor_dims = arg_s->dim();
std::vector<std::size_t> dims;
for (auto&& dim : tensor_dims) {
dims.emplace_back(dim.has_dim_value() ? dim.dim_value() : 0);
}
if (dims == std::vector<std::size_t>{0}) {
return true;
}
}

const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
}
} else if (optype == "Tile") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
} else if (optype == "TopK") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
} else if (optype == "Unsqueeze" or optype == "Squeeze") {

Check warning on line 512 in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc#L512

Use operator || instead of or [readability/alt_tokens] [2]
Raw output
onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc:512:  Use operator || instead of or  [readability/alt_tokens] [2]
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
}
return true;
}
}

// Op doesn't fall into known any of unsupported modes.
return false;
// Op doesn't fall into known any of unsupported modes.
return false;
}

void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector<std::vector<NodeIndex>>& clusters,
Expand Down
Loading
Loading