Skip to content

Commit

Permalink
Add int4 type support for MIGraphX
Browse files Browse the repository at this point in the history
  • Loading branch information
TedThemistokleous committed Oct 11, 2024
1 parent a55b735 commit e3933cb
Showing 1 changed file with 8 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ static bool IsTypeSupported(const NodeArg* node_arg) {
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16:
case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32:
Expand Down Expand Up @@ -264,6 +266,12 @@ static bool getMIGraphXType(ONNXTensorElementDataType type,
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
mgx_type = migraphx_shape_int8_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4:
mgx_type = migraphx_shape_uint4_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4:
mgx_type = migraphx_shape_int4_type;
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
mgx_type = migraphx_shape_int16_type;
break;
Expand Down

0 comments on commit e3933cb

Please sign in to comment.