Skip to content

Commit

Permalink
Support signed and unsigned integer types in migraphx dialect (#1692)
Browse files Browse the repository at this point in the history
* Support signed and unsigned integer types in migraphx dialect

* fix realize-int4.mlir tests and add more tests

* Fix some bugs (quantize and custom ops) and add more tests.

* Add assert to check we don't convert from unsigned to signed or signed
to unsigned

* Add more tests

* Fix uint division test
  • Loading branch information
dhernandez0 authored Nov 8, 2024
1 parent 290cd49 commit 99fc9d2
Show file tree
Hide file tree
Showing 14 changed files with 600 additions and 135 deletions.
24 changes: 8 additions & 16 deletions mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def MIGraphX_ClipOp :
// Keep that logic here.
def MIGraphX_WhereOp :
MIGraphX_Op<"where">,
Arguments<(ins MIXRShapedOf<[I8]>:$cond,
Arguments<(ins MIXRShapedOf<[I8, SI8, UI8]>:$cond,
AnyMIXRShaped:$inA,
AnyMIXRShaped:$inB)>,
Results<(outs AnyMIXRShaped:$output)> {
Expand All @@ -117,18 +117,14 @@ def MIGraphX_WhereOp :

def MIGraphX_ConvertOp :
MIGraphX_Op<"convert">,
Arguments<(ins AnyMIXRShaped:$inA, UnitAttr:$zeroExtend)>,
Arguments<(ins AnyMIXRShaped:$inA)>,
Results<(outs AnyMIXRShaped:$output)> {
let summary = "Elementwise type conversion";
let description = [{
Type conversion. Due to impedance mismatches between MIGraphX and Tosa,
currently only supports float to float conversions

If zeroExtend is set, the input is treated as an unsigned integer.
This is MLIR-specific, since MIGraphX encodes integer signedness in types
but MLIR generally uses signless integers.
}];
let assemblyFormat = "(`zero_extend` $zeroExtend^)? $inA attr-dict `:` type($inA) `to` type($output)";
let assemblyFormat = "$inA attr-dict `:` type($inA) `to` type($output)";
}

class MIGraphX_ElementwiseUnaryOp<string name, list<Trait> traits=[]> :
Expand Down Expand Up @@ -181,10 +177,9 @@ def MIGraphX_TanhOp :

// int4 operations
def MIGraphX_UnpackOp : MIGraphX_Op<"unpack">,
Arguments<(ins MIXRShapedOf<[I8, I<4>]>:$in,
I64Attr:$axis,
BoolAttr:$isUnsigned)>,
Results<(outs MIXRShapedOf<[I8, I<4>]>:$out)> {
Arguments<(ins MIXRShapedOf<[I8, UI8, SI8, I<4>, SI<4>, UI<4>]>:$in,
I64Attr:$axis)>,
Results<(outs MIXRShapedOf<[I8, UI8, SI8, I<4>, SI<4>, UI<4>]>:$out)> {
let summary = "Unpack int4 vaules stored as bytes";
let description = [{
Given a shaped tensor of bytes, double the length of `axis` by
Expand All @@ -201,9 +196,6 @@ def MIGraphX_UnpackOp : MIGraphX_Op<"unpack">,
the corresponding tensor of i8 (in which case, the `i4` are exposed as an
extra dimension and not flattened) or another tensor of i4. This allows us to
progressively move unpack up to function boundaries.

If `isUnsigned` is true, the inputs are a buffer of unsigned 4-bit ints,
otherwise they are signed.
}];

let assemblyFormat = [{ $in attr-dict `:` type($in) `->` type($out) }];
Expand Down Expand Up @@ -333,7 +325,7 @@ class MIGraphX_ConvOpBase<string mnemonic, list<Type> inputTypes=[], list<Type>
}

def MIGraphX_QuantConvolutionOp :
MIGraphX_ConvOpBase<"quant_convolution", [F8E4M3FNUZ, F8E5M2FNUZ, F8E5M2, F8E4M3FN, I8], [F32, I32]> {
MIGraphX_ConvOpBase<"quant_convolution", [F8E4M3FNUZ, F8E5M2FNUZ, F8E5M2, F8E4M3FN, I8, SI8], [F32, I32, SI32]> {
let summary = "quantized convolution forward";
let description = [{
The `migraphx.quant_convolution` op computes quantized convolution forward.
Expand Down Expand Up @@ -510,7 +502,7 @@ class MIGraphX_DotOpBase<string mnemonic, list<Type> inputTypes=[], list<Type> o
}

def MIGraphX_QuantDotOp :
MIGraphX_DotOpBase<"quant_dot", [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2, I8], [F32, I32]>{
MIGraphX_DotOpBase<"quant_dot", [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2, I8, SI8], [F32, I32, SI32]>{
let summary = "Dot product of quantized tensors";
let description = [{
The `migraphx.quant_dot` op computes the dot product of two tensors.
Expand Down
Loading

0 comments on commit 99fc9d2

Please sign in to comment.