From 7076f7463507e72f1fa01f81ef3ebb33e21ab4b8 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 4 Dec 2024 23:05:49 +0000 Subject: [PATCH] Add support for int4 inputs Map things to int8 right now as we don't explicitly set an int4 input type and pack/unpack int4 operands --- .../providers/migraphx/migraphx_execution_provider.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index f320bf61f0ddf..c1cae43480ea2 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -237,10 +237,12 @@ static bool IsTypeSupported(const NodeArg* node_arg) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32: @@ -277,6 +279,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: mgx_type = migraphx_shape_fp8e5m2fnuz_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + mgx_type = migraphx_shape_int8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; @@ -289,6 +294,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: mgx_type = migraphx_shape_int64_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: + mgx_type = migraphx_shape_uint8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: mgx_type = migraphx_shape_uint8_type; break;