Skip to content

Commit

Permalink
Ensure we support all inputs for MatMulInteger and ConvInteger. Limit… (
Browse files Browse the repository at this point in the history
#52)

* Ensure we support all inputs for MatMulInteger and ConvInteger. Limit to int8 for now

Allow for models with biases/full input and only check for int8 support in EP

* Add support for uint8 types

---------

Co-authored-by: Ted Themistokleous <[email protected]>
  • Loading branch information
TedThemistokleous and Ted Themistokleous authored Aug 21, 2024
1 parent 74b5fa4 commit c695953
Showing 1 changed file with 134 additions and 134 deletions.
268 changes: 134 additions & 134 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,22 +322,14 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
return true;
}
} else if (optype == "ConvInteger") {
if (node->InputDefs()[0]->Shape()->dim_size() != 4) {
return true;
}

// migraphx can handle only two inputs
if (node->InputDefs().size() != 2) {
return true;
}

// only support int8 type
// only support int8 and uint8 type
const auto& input_type = node->InputDefs()[0]->TypeAsProto();
if (input_type == nullptr) {
return true;
}

if (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) and
(input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) {
return true;
}
} else if (optype == "Expand") {
Expand Down Expand Up @@ -379,161 +371,169 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co
return true;
}
} else if (optype == "MatMulInteger") {
// migraphx can handle only two inputs
if (node->InputDefs().size() != 2) {
return true;
}

// only support int8 type
// only support int8 and uint8 type
const auto& input_type = node->InputDefs()[0]->TypeAsProto();
if (input_type == nullptr) {
return true;
}

if (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
return true;
}
} else if (optype == "NonZero") {
if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes)) {
if ((input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) and
(input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)) {
return true;
}
} else if (optype == "OneHot") {
}
}
else if (optype == "NonZero") {
if (!canEvalNodeArgument(graph_viewer, node, {0}, input_nodes)) {
return true;
}
}
else if (optype == "OneHot") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}
else if (optype == "Pad") {
const auto& args = node->InputDefs();
// if pad size is not constant, migraphx cannot support
if (args.size() >= 2) {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
} else if (optype == "Pad") {
const auto& args = node->InputDefs();
// if pad size is not constant, migraphx cannot support
if (args.size() >= 2) {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}
}

const auto& attributes = node->GetAttributes();
// Pad only support constant mode
auto mode_attr = attributes.find("mode");
std::string mode = "constant";
if (mode_attr != attributes.end()) {
mode = (*mode_attr).second.s();
}
static const std::set<std::string> allowed_modes = {"constant", "reflect"};
if (allowed_modes.count(mode) == 0) {
return true;
}
const auto& attributes = node->GetAttributes();
// Pad only support constant mode
auto mode_attr = attributes.find("mode");
std::string mode = "constant";
if (mode_attr != attributes.end()) {
mode = (*mode_attr).second.s();
}
static const std::set<std::string> allowed_modes = {"constant", "reflect"};
if (allowed_modes.count(mode) == 0) {
return true;
}

// input value only applied to constant mode
if (mode == "constant") {
if (args.size() == 3) {
if (!canEvalNodeArgument(graph_viewer, node, {2}, input_nodes)) {
return true;
}
// input value only applied to constant mode
if (mode == "constant") {
if (args.size() == 3) {
if (!canEvalNodeArgument(graph_viewer, node, {2}, input_nodes)) {
return true;
}
}
} else if (optype == "Range") {
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) {
return true;
}
}
else if (optype == "Range") {
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) {
return true;
}
}
else if (optype == "Reshape") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
} else if (optype == "Reshape") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
}
}
else if (optype == "Resize" or optype == "Upsample") {
const auto& attributes = node->GetAttributes();
auto ct_attr = attributes.find("coordinate_transformation_mode");
if (ct_attr != attributes.end()) {
auto ct = (*ct_attr).second.s();
if (ct == "tf_crop_and_resize") {
return true;
}
} else if (optype == "Resize" or optype == "Upsample") {
const auto& attributes = node->GetAttributes();
auto ct_attr = attributes.find("coordinate_transformation_mode");
if (ct_attr != attributes.end()) {
auto ct = (*ct_attr).second.s();
if (ct == "tf_crop_and_resize") {
return true;
}
}

auto mode_attr = attributes.find("mode");
if (mode_attr != attributes.end()) {
auto mode = (*mode_attr).second.s();
if (mode == "cubic") {
return true;
}
}
}

} else if (optype == "ReduceSum") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
auto mode_attr = attributes.find("mode");
if (mode_attr != attributes.end()) {
auto mode = (*mode_attr).second.s();
if (mode == "cubic") {
return true;
}
} else if (optype == "Slice") {
// MIGraphX does not properly handle the situation where any
// value of the "starts" attribute is higher than a corresponding
// value in the "ends"
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
vec.erase(vec.begin());
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) {
return true;
}
}
else if (optype == "ReduceSum") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
}
}
else if (optype == "Slice") {
// MIGraphX does not properly handle the situation where any
// value of the "starts" attribute is higher than a corresponding
// value in the "ends"
auto arg_num = node->InputDefs().size();
std::vector<std::size_t> vec(arg_num);
std::iota(vec.begin(), vec.end(), 0);
vec.erase(vec.begin());
if (!canEvalNodeArgument(graph_viewer, node, vec, input_nodes)) {
return true;
}

const auto& attributes = node->GetAttributes();
if (attributes.count("starts") > 0 and attributes.count("ends") > 0) {
auto starts = toVector((*attributes.find("starts")).second.ints());
auto ends = toVector((*attributes.find("ends")).second.ints());
for (std::size_t i = 0; i < starts.size(); ++i) {
if (starts.at(i) > ends.at(i)) {
return true;
}
}
}
} else if (optype == "Split") {
// cannot process input dim of 0 size
const auto arg_s = node->InputDefs()[0]->Shape();
if (arg_s != nullptr) {
const auto& tensor_dims = arg_s->dim();
std::vector<std::size_t> dims;
for (auto&& dim : tensor_dims) {
dims.emplace_back(dim.has_dim_value() ? dim.dim_value() : 0);
}
if (dims == std::vector<std::size_t>{0}) {
const auto& attributes = node->GetAttributes();
if (attributes.count("starts") > 0 and attributes.count("ends") > 0) {
auto starts = toVector((*attributes.find("starts")).second.ints());
auto ends = toVector((*attributes.find("ends")).second.ints());
for (std::size_t i = 0; i < starts.size(); ++i) {
if (starts.at(i) > ends.at(i)) {
return true;
}
}

const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
}
}
else if (optype == "Split") {
// cannot process input dim of 0 size
const auto arg_s = node->InputDefs()[0]->Shape();
if (arg_s != nullptr) {
const auto& tensor_dims = arg_s->dim();
std::vector<std::size_t> dims;
for (auto&& dim : tensor_dims) {
dims.emplace_back(dim.has_dim_value() ? dim.dim_value() : 0);
}
} else if (optype == "Tile") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
if (dims == std::vector<std::size_t>{0}) {
return true;
}
} else if (optype == "TopK") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}

const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
} else if (optype == "Unsqueeze" or optype == "Squeeze") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
return true;
}
}
else if (optype == "Tile") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}
else if (optype == "TopK") {
if (!canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return true;
}
}
else if (optype == "Unsqueeze" or optype == "Squeeze") {
const auto& args = node->InputDefs();
if (args.size() == 2) {
if (canEvalNodeArgument(graph_viewer, node, {1}, input_nodes)) {
return false;
}
return true;
}
}

// Op doesn't fall into known any of unsupported modes.
return false;
// Op doesn't fall into known any of unsupported modes.
return false;
}

void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector<std::vector<NodeIndex>>& clusters,
Expand Down

0 comments on commit c695953

Please sign in to comment.