From 8ea4371124030fc0ca8af8b407c5f165e6f15bec Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 11 Oct 2024 21:35:34 +0000 Subject: [PATCH] Add simplified layer norm ops --- .../core/providers/migraphx/migraphx_execution_provider.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 15a23aa97ef36..92de01bc8970d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -233,6 +233,8 @@ static bool IsTypeSupported(const NodeArg* node_arg) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: @@ -906,8 +908,10 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "Shape", "Sigmoid", "Sign", + "SimplifiedLayerNormalization", "Sin", "Sinh", + "SkipSimplifiedLayerNormalization", "Slice", "Softmax", "SoftmaxCrossEntropyLoss",