Adds Some Quantization Support for AtenMatmulOp (llvm#3147)
1. onnx.MatMulInteger now converts to aten.matmul instead of aten.mm
2. aten.matmul, for ranks >= 2, now allows quantized inputs and will
lower to linalg::quantized_matmul or linalg::quantized_batch_matmul
(a condensed sketch of this dispatch appears after this list).
3. added AtenMatmulOp to the FuseQuantizedOps rewrite patterns
QuantizeOperands, QuantizeTransposedOperands, and QuantizeAccumulator
4. added several tests, including some to test AtenMmOp with varying
quantization signedness.
5. added a quantized matmul mat-vec test to verify that it fails to
lower to linalg; cleaned out some out-of-date code related to common
torch-mlir lowering xfails.
6. while debugging a real model with quantized matmuls, I found a bug in
the scalarize-shapes pass which resulted from the aten.full op folder
returning an incompatible result type. This is fixed by the small change
here to
[lib/Dialect/Torch/IR/TorchOps.cpp](https://github.com/llvm/torch-mlir/compare/main...zjgarvey:torch-mlir:MatMulIntegerFix?expand=1#diff-dc8ed165c207918e606490eee3984b1ad51d7034e6aac36fc046bf47f6f03f4f).
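For orientation, here is a condensed sketch of the rank-based dispatch this commit adds to ConvertAtenMatmulOp. It is abbreviated from the Linear.cpp diff below, with broadcasting and most error handling elided; it is illustrative, not the verbatim pattern.

```cpp
// Sketch only: condensed from the ConvertAtenMatmulOp changes below.
if (lhsZeroPoint) {
  // Quantized path. Vector operands are rejected (hence the mat-vec test),
  // and the zero points ride along as extra linalg operands.
  if (lhsRank < 2 || rhsRank < 2)
    return rewriter.notifyMatchFailure(
        op, "unsupported: quantized aten.mm with vector");
  if (lhsRank == 2 && rhsRank == 2) {
    // 2-D x 2-D lowers to linalg.quantized_matmul.
    matmul = rewriter
                 .create<linalg::QuantizedMatmulOp>(
                     loc, zeroTensor.getType(),
                     ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
                     zeroTensor)
                 .getResult(0);
  } else {
    // Batched cases lower to linalg.quantized_batch_matmul on the
    // broadcasted or collapsed batch dimensions.
    matmul = rewriter
                 .create<linalg::QuantizedBatchMatmulOp>(
                     loc, zeroTensor.getType(),
                     ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
                     zeroTensor)
                 .getResult(0);
  }
}
```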
zjgarvey authored Apr 15, 2024
1 parent 5708ee7 commit 5e564b5
Showing 9 changed files with 363 additions and 65 deletions.
10 changes: 5 additions & 5 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -311,6 +311,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.tensorResultType(resultType))
return failure();

auto lhsTy = dyn_cast<Torch::ValueTensorType>(lhs.getType());
auto rhsTy = dyn_cast<Torch::ValueTensorType>(rhs.getType());

if (binder.tensorOperandAtIndex(lhsZp, 2)) {
lhsZp = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
@@ -323,9 +326,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
}

auto lhsTy = dyn_cast<Torch::ValueTensorType>(lhs.getType());
auto rhsTy = dyn_cast<Torch::ValueTensorType>(rhs.getType());

if (auto zpTy = dyn_cast<Torch::ValueTensorType>(lhsZp.getType())) {
for (auto dim : zpTy.getSizes())
if (dim != 1)
@@ -366,8 +366,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
binder.getLoc(), rhsQTy, rhs, scale, rhsZp);

rewriter.replaceOpWithNewOp<Torch::AtenMmOp>(binder.op, resultType, lhs,
rhs);
rewriter.replaceOpWithNewOp<Torch::AtenMatmulOp>(binder.op, resultType,
lhs, rhs);
return success();
});
patterns.onOp("Mul", 7,
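For context, the ONNX spec defines MatMulInteger as an integer matmul whose operands are de-biased by their zero points, with the result accumulated in int32:

```latex
Y_{\mathrm{i32}} = (A - a\_zero\_point)\,(B - b\_zero\_point)
```

Wrapping lhs and rhs in Torch::Aten_MakePerTensorQuantizedTensorOp carries exactly these zero points into the torch dialect, and targeting aten.matmul instead of the strictly two-dimensional aten.mm lets batched (rank > 2) MatMulInteger inputs convert as well.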
163 changes: 121 additions & 42 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -36,6 +36,26 @@ static void getZeroPoint(Value value, Value &zeropoint) {
}
}

// for uint8 types, we shift down by 128 so that we can faithfully
// represent the quantization with signed i8 types.
static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
Value &zp, bool isUnsignedType, int64_t numBits) {
if (!isUnsignedType)
return;
int64_t minSI = -(1 << (numBits - 1));
Value minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, 32);
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
arg = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, ValueRange{arg},
arg.getType().cast<TensorType>().getElementType(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result =
rewriter.create<arith::AddIOp>(loc, payloadArgs[0], minSIValue);
b.create<linalg::YieldOp>(loc, result);
});
}

static Value transposeValue(Location loc, Value value, ArrayRef<int64_t> perms,
PatternRewriter &rewriter) {
auto valueTy = value.getType().cast<RankedTensorType>();
@@ -154,32 +174,12 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
rhsZeroPoint = rewriter.create<arith::TruncIOp>(
loc, rewriter.getI32Type(), rhsZeroPoint);

// for uint8 types, we shift down by 128 so that we can faithfully
// represent the quantization with signed i8 types.
auto signShift = [&](Value &arg, Value &zp, bool isUnsignedType,
int64_t numBits) {
if (!isUnsignedType)
return;
int64_t minSI = -std::pow(2, numBits - 1);
Value minSIValue =
rewriter.create<arith::ConstantIntOp>(loc, minSI, 32);
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
arg = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, ValueRange{arg},
arg.getType().cast<TensorType>().getElementType(),
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value result = rewriter.create<arith::AddIOp>(loc, payloadArgs[0],
minSIValue);
b.create<linalg::YieldOp>(loc, result);
});
};

// change uint8 quantization -> int8 quantization
int64_t numBits =
lhsType.getElementType().cast<mlir::IntegerType>().getWidth();
signShift(lhs, lhsZeroPoint, isUnsigned, numBits);
signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
numBits = rhsType.getElementType().cast<mlir::IntegerType>().getWidth();
signShift(rhs, rhsZeroPoint, isUnsignedR, numBits);
signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);

matmul =
rewriter
@@ -302,10 +302,57 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
auto lhsType = lhs.getType().cast<RankedTensorType>();
auto rhsType = rhs.getType().cast<RankedTensorType>();

auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
auto rhsTorchType = cast<ValueTensorType>(op.getOther().getType());

// Get the rank of both matrices.
unsigned lhsRank = lhsType.getRank();
unsigned rhsRank = rhsType.getRank();

Value lhsZeroPoint, rhsZeroPoint;
getZeroPoint(op.getSelf(), lhsZeroPoint);
getZeroPoint(op.getOther(), rhsZeroPoint);

if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
return rewriter.notifyMatchFailure(
op, "unsupported: aten.matmul with mixed quantization");
}

bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType);

if (!lhsZeroPoint && lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
// Allows quantized types to mismatch
return rewriter.notifyMatchFailure(
op, "unsupported: aten.matmul with different input element types");
}

if (lhsZeroPoint) {
if (lhsRank < 2 || rhsRank < 2) {
return rewriter.notifyMatchFailure(
op, "unsupported: quantized aten.mm with vector");
}
lhsZeroPoint = typeConverter->materializeTargetConversion(
rewriter, loc,
getTypeConverter()->convertType(lhsZeroPoint.getType()),
lhsZeroPoint);
rhsZeroPoint = typeConverter->materializeTargetConversion(
rewriter, loc,
getTypeConverter()->convertType(rhsZeroPoint.getType()),
rhsZeroPoint);
lhsZeroPoint = rewriter.create<arith::TruncIOp>(
loc, rewriter.getI32Type(), lhsZeroPoint);
rhsZeroPoint = rewriter.create<arith::TruncIOp>(
loc, rewriter.getI32Type(), rhsZeroPoint);

// change uint8 quantization -> int8 quantization
int64_t numBits =
lhsType.getElementType().cast<mlir::IntegerType>().getWidth();
signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
numBits = rhsType.getElementType().cast<mlir::IntegerType>().getWidth();
signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
}

Type newResultType = getTypeConverter()->convertType(op.getType());
auto resultType = cast<RankedTensorType>(newResultType);
Type elementType = resultType.getElementType();
@@ -366,7 +413,7 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
return success();
}

// Fourth Case: Vec-Vec Multiplication.
// Fourth Case: Mat-Mat Multiplication.
if (lhsRank == 2 && rhsRank == 2) {
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
@@ -376,11 +423,20 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {

Value zeroTensor = createZeroInitTensor(
rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
Value matmul =
rewriter
.create<linalg::MatmulOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
Value matmul;
if (lhsZeroPoint) {
matmul = rewriter
.create<linalg::QuantizedMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
zeroTensor)
.getResult(0);
} else {
matmul = rewriter
.create<linalg::MatmulOp>(loc, zeroTensor.getType(),
ValueRange{lhs, rhs}, zeroTensor)
.getResult(0);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
@@ -465,12 +521,24 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
rewriter, loc,
ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1},
elementType);
Value matmul =
rewriter
.create<linalg::BatchMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor)
.getResult(0);
Value matmul;
if (lhsZeroPoint) {
matmul = rewriter
.create<linalg::QuantizedBatchMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{broadcastedLhs, broadcastedRhs,
lhsZeroPoint, rhsZeroPoint},
zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
matmul);
return success();
}
matmul = rewriter
.create<linalg::BatchMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
@@ -519,13 +587,24 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
loc, rewriter.getZeroAttr(elementType));
Value zeroTensor =
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);

Value batchMatMul =
rewriter
.create<linalg::BatchMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{collapsedLhs, collapsedRhs}, zeroTensor)
.getResult(0);
Value batchMatMul;

if (lhsZeroPoint) {
batchMatMul = rewriter
.create<linalg::QuantizedBatchMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{collapsedLhs, collapsedRhs,
lhsZeroPoint, rhsZeroPoint},
zeroTensor)
.getResult(0);
} else {
batchMatMul =
rewriter
.create<linalg::BatchMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{collapsedLhs, collapsedRhs}, zeroTensor)
.getResult(0);
}
Value expandResult = rewriter.create<tensor::ExpandShapeOp>(
loc, resultType, batchMatMul, reassociation);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
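Since the uint8 -> int8 shift is the crux of the new quantized support, here is a tiny standalone check (not from the patch) of the arithmetic the shared signShift helper relies on:

```cpp
#include <cassert>
#include <cstdint>

// Standalone demonstration of the arithmetic behind signShift: shifting a
// uint8 value and its zero point down by 2^(bits-1) = 128 moves the value
// into signed-i8 range while leaving the dequantization term (q - zp)
// unchanged, which is why a uint8 quantized matmul can be computed with the
// signed linalg quantized ops.
int main() {
  const int numBits = 8;
  const int32_t minSI = -(1 << (numBits - 1)); // -128
  int32_t q = 200, zp = 131;      // an example uint8 value and zero point
  int32_t qShifted = q + minSI;   // 72, fits in i8
  int32_t zpShifted = zp + minSI; // 3, held in the i32 zero-point operand
  assert(q - zp == qShifted - zpShifted); // accumulation term is preserved
  return 0;
}
```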
4 changes: 2 additions & 2 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4049,8 +4049,8 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
if (sz == Torch::kUnknownSize || sz < 0)
return nullptr;

ShapedType shapedty =
mlir::RankedTensorType::get(sizes, resultTensorType.getDtype());
ShapedType shapedty = mlir::RankedTensorType::get(
resultTensorType.getSizes(), resultTensorType.getDtype());

auto elementType = shapedty.getElementType();
if (isa<IntegerType>(elementType)) {
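The TorchOps.cpp change is small but load-bearing: a folder replaces the op with a constant, so the attribute it returns must be typed with the op's declared result type, not with a shape recomputed from the op's size-list operand. A minimal sketch of that invariant, assuming a splat integer fill (`fillValue` is illustrative, not from the patch):

```cpp
// Sketch of the folder invariant behind the AtenFullOp fix: build the splat
// attribute from the op's own result sizes and dtype so the folded constant's
// type agrees with the op it replaces (the mismatch is what broke the
// scalarize-shapes pass). `resultTensorType` matches the variable in the diff.
ShapedType shapedTy = mlir::RankedTensorType::get(
    resultTensorType.getSizes(), resultTensorType.getDtype());
Attribute fillValue = IntegerAttr::get(shapedTy.getElementType(), 0);
return DenseElementsAttr::get(shapedTy, fillValue); // splat constant
```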
8 changes: 5 additions & 3 deletions lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -285,9 +285,11 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
RemoveUnused<AtenQuantizePerTensorOp>,
RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
RemoveUnused<AtenTransposeIntOp>, QuantizeOperands<AtenConvolutionOp>,
QuantizeOperands<AtenMmOp>, QuantizeTransposedOperands<AtenMmOp>,
QuantizeAccumulator<AtenMmOp>, QuantizeBias<AtenConvolutionOp>>(
context);
QuantizeOperands<AtenMatmulOp>,
QuantizeTransposedOperands<AtenMatmulOp>,
QuantizeAccumulator<AtenMatmulOp>, QuantizeOperands<AtenMmOp>,
QuantizeTransposedOperands<AtenMmOp>, QuantizeAccumulator<AtenMmOp>,
QuantizeBias<AtenConvolutionOp>>(context);

GreedyRewriteConfig config;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
14 changes: 14 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -150,6 +150,7 @@
'AtenIntBoolOpModule_basic',
'QuantizedMLP_basic',
'QuantizedSingleLayer_basic',
'QuantizedBatchedInputSingleLayer_basic',
'QuantizedNoLayer_basic',
'ScalarImplicitFloatModule_basic',
'ScalarImplicitIntModule_basic',
@@ -317,6 +318,12 @@
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"AtenMmQuint8_basic",
"AtenMmQint8_basic",
"AtenMmQMixedSigni8_basic",
"AtenMatmulQMixedSigni8Transpose_basic",
"AtenMatmulQMixedSigni8_basic",
"AtenMatmulQint8MV_basic",
"AtenMatmulQint8_basic",
"Conv2dQInt8Module_basic",

# Dynamo not supporting conv_tbc
@@ -1802,6 +1809,7 @@
"NeIntModule_basic",
"QuantizedMLP_basic",
"QuantizedSingleLayer_basic",
"QuantizedBatchedInputSingleLayer_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
"SliceEndSleStartModule_basic",
@@ -1952,7 +1960,13 @@
"AtenIntTensorCharDtypeModule_basic",
"AtenItemFpOpModule_basic",
"AtenItemIntOpModule_basic",
"AtenMmQint8_basic",
"AtenMmQuint8_basic",
"AtenMmQMixedSigni8_basic",
"AtenMatmulQMixedSigni8Transpose_basic",
"AtenMatmulQMixedSigni8_basic",
"AtenMatmulQint8MV_basic",
"AtenMatmulQint8_basic",
"AtenRealView128Module_basic",
"AtenRealView64Module_basic",
"AtenSubFloatModule_basic",
10 changes: 2 additions & 8 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -13,19 +13,13 @@
"NativeGroupNormBackwardModule_basic",
"QuantizedMLP_basic",
"QuantizedSingleLayer_basic",
"QuantizedBatchedInputSingleLayer_basic",
"AtenMatmulQint8MV_basic",
"ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
}

# TODO: Delete once torch 2.1.0 is released
if torch_version_for_comparison() < version.parse("2.1.0.dev"):
COMMON_TORCH_MLIR_LOWERING_XFAILS.update({
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionSameModule_basic"
})


def register_all_tests():
"""Registers all the built-in E2E tests that Torch-MLIR provides."""
# Side-effecting import statements.
