diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 737372c2a4fb9..f3c0a23ec2ad6 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -785,148 +785,11 @@ static std::vector GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, /*out*/ std::unordered_set& mgx_required_initializers, const logging::Logger& logger) { - static std::set mgx_supported_ops = {"Abs", - "Acos", - "Acosh", - "Add", - "And", - "ArgMax", - "ArgMin", - "Asin", - "Asinh", - "Atan", - "Atanh", - "ATen", - "AveragePool", - "BatchNormalization", - "Cast", - "Ceil", - "Celu", - "Clip", - "Concat", - "Constant", - "ConstantFill", - "ConstantOfShape", - "Conv", - "ConvInteger", - "ConvTranspose", - "Cos", - "Cosh", - "CumSum", - "DepthToSpace", - "DequantizeLinear", - "Div", - "Dropout", - "Elu", - "Equal", - "Erf", - "Exp", - "Expand", - "EyeLike", - "Flatten", - "Floor", - "GRU", - "Gather", - "GatherElements", - "GatherND", - "Gemm", - "GlobalAveragePool", - "GlobalMaxPool", - "Greater", - "GreaterOrEqual", - "HardSigmoid", - "HardSwish", - "Identity", - "If", - "ImageScaler", - "InstanceNormalization", - "IsNan", - "LeakyRelu", - "Less", - "LessOrEqual", - "Log", - "LogSoftmax", - "Loop", - "LpNormalization", - "LRN", - "LSTM", - "MatMul", - "MatMulInteger", - "Max", - "MaxPool", - "Mean", - "Min", - "Mod", - "Mul", - "Multinomial", - "Neg", - "NonMaxSuppression", - "NonZero", - "Not", - "OneHot", - "Or", - "Pad", - "Pow", - "PRelu", - "QLinearAdd", - "QLinearConv", - "QLinearMatMul", - "QuantizeLinear", - "DynamicQuantizeLinear", - "RandomNormal", - "RandomNormalLike", - "RandomUniform", - "RandomUniformLike", - "Range", - "Reciprocal", - "ReduceL1", - "ReduceL2", - "ReduceLogSum", - "ReduceLogSumExp", - "ReduceMax", - "ReduceMean", - "ReduceMin", - "ReduceProd", - "ReduceSum", - "ReduceSumSquare", - "Relu", - "Reshape", - "Resize", - "ReverseSequence", - "RNN", - "Roialign", - "Round", - "Scatter", - "ScatterElements", - "ScatterND", - "Selu", - "Shape", - "Sigmoid", - "Sign", - "Sin", - "Sinh", - "Slice", - "Softmax", - "SoftmaxCrossEntropyLoss", - "Softplus", - "Softsign", - "SpaceToDepth", - "Split", - "Sqrt", - "Squeeze", - "Sub", - "Sum", - "Tan", - "Tanh", - "ThresholdedRelu", - "Tile", - "TopK", - "Transpose", - "Trilu", - "Unsqueeze", - "Upsample", - "Where", - "Xor"}; + + // Leverage MIGraphX API to tell is which operators we support. + auto migx_ops{migraphx::get_onnx_operators()}; + static std::set mgx_supported_ops(std::make_move_iterator(migx_ops.begin()), std::make_move_iterator(migx_ops.end())); + std::vector unsupported_nodes_idx; for (const auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) { if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) {