From 36cd087138b1013b41d4791b2ba33ef0ceae78e0 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Sat, 7 Sep 2024 03:43:29 +0000 Subject: [PATCH] Add support for softmaxcrossentropy loss to MIGraphX EP --- .../providers/migraphx/migraphx_execution_provider.cc | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index f7e7299400058..737372c2a4fb9 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -907,6 +907,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Sinh", "Slice", "Softmax", + "SoftmaxCrossEntropyLoss", "Softplus", "Softsign", "SpaceToDepth", @@ -1018,15 +1019,6 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v return result; } - // migraphx cannot handle Loop, If, and SoftmaxCrossEntropyLoss for now, - // so if a model contain any of these operators, fall back to CPU - std::unordered_set vec_ops = {"SoftmaxCrossEntropyLoss"}; - if (std::any_of(unsupported_nodes.begin(), unsupported_nodes.end(), [&](auto i) { - return (vec_ops.count(graph_viewer.GetNode(i)->OpType()) > 0); - })) { - return result; - } - auto mgx_clusters = GetPartitionedSubgraphs(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes); // check whether a subgrap should fallback to CPU