[torch] Adds Quantization Support for aten.relu (llvm#3177)
We chose to quantize the return type of Relu with the scale and zero point copied
from the input's quantization scheme. With this choice, the torch-to-linalg
conversion of quantized Relu reduces to computing max(input, zeroPoint) in the
elementwise payload.
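
With the output quantized using the input's scale s and zero point z, relu(s * (q - z)) == s * (max(q, z) - z), so the integer-domain payload reduces to max(q, z). Below is a minimal PyTorch sketch of that identity (illustration only, not part of this commit; the tensor values and the scale/zero-point pair are arbitrary):

import torch

# Arbitrary int8 data and an arbitrary (scale, zero point) pair, for illustration only.
x = torch.randint(-128, 127, (3, 4), dtype=torch.int8)
scale, zero_point = 0.0215, -25
qx = torch._make_per_tensor_quantized_tensor(x, scale, zero_point)

# Reference: dequantize, then apply ordinary relu.
reference = torch.relu(torch.dequantize(qx))

# Integer-domain payload: max(q, zero_point), then dequantize with the
# input's scale and zero point (the scheme chosen by this commit).
q = x.to(torch.int32)
payload = (torch.clamp(q, min=zero_point) - zero_point).to(torch.float32) * scale

assert torch.allclose(reference, payload)
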
zjgarvey authored Apr 23, 2024
1 parent 09d4204 commit a8ba865
Showing 4 changed files with 216 additions and 15 deletions.
78 changes: 67 additions & 11 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -56,6 +56,13 @@ static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
llvm_unreachable("Unhandled element type for comparison");
}

static Value getZeroPoint(Value value) {
if (auto make = value.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
return make.getZeroPoint();
}
return nullptr;
}

static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
Value lhs, Value rhs) {
return createComparisonTemplate<arith::CmpFPredicate::OGT,
@@ -528,19 +535,68 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy);
}
if (auto relu = dyn_cast<AtenReluOp>(op)) {
if (!relu.getType()
.cast<ValueTensorType>()
.getDtype()
.isa<mlir::FloatType>()) {
relu.emitError("unimplemented: non-floating point dtype");
Value zeroPoint = getZeroPoint(relu.getSelf());
Value arg = payloadArgs[0];
auto intType = arg.getType().dyn_cast<mlir::IntegerType>();
if (zeroPoint && !intType) {
relu.emitError("unimplemented: non-integer quantized Relu.");
return nullptr;
}
Type elementType = payloadArgs[0].getType();
Value constZero =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
payloadArgs[0], constZero);
return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], constZero);
auto reluTorchType = cast<ValueTensorType>(relu.getType());
bool isUnsigned =
torch_to_linalg::isUnsignedTorchType(reluTorchType.getDtype());
if (zeroPoint) {
int64_t zeroPointInt;
int64_t width = intType.getWidth();
assert(width < 64);
int64_t minForIntType = isUnsigned ? 0 : -(1 << (width - 1));
int64_t maxForIntType =
isUnsigned ? (1 << (width + 1)) - 1 : (1 << (width - 1)) - 1;
// check for constant zero point edge-cases:
if (matchPattern(zeroPoint, m_TorchConstantInt(&zeroPointInt))) {
if (zeroPointInt > maxForIntType) {
// TODO: figure out how to handle this case:
// current impl. quantizes output like input.
// If zero point > maxForIntType, ordinary relu should return 0.
// However, 0 isn't represented in such a quantization scheme.
relu.emitError(
"unimplemented: quantized relu for zero-point > max qint");
return nullptr;
}
if (zeroPointInt < minForIntType)
return arg;
}
zeroPoint = converter->materializeTargetConversion(
b, loc, converter->convertType(zeroPoint.getType()), zeroPoint);
auto minForIntTypeValue = b.create<arith::ConstantOp>(
loc, b.getIntegerAttr(zeroPoint.getType(), minForIntType));
auto maxForIntTypeValue = b.create<arith::ConstantOp>(
loc, b.getIntegerAttr(zeroPoint.getType(), maxForIntType));
auto zpLtMax = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
zeroPoint, maxForIntTypeValue);
b.create<cf::AssertOp>(
loc, zpLtMax,
b.getStringAttr("Invalid Quantization: quantized relu with "
"zero-point > max qint"));
auto zpLtMin = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
zeroPoint, minForIntTypeValue);
zeroPoint = b.create<arith::SelectOp>(loc, zpLtMin, minForIntTypeValue,
zeroPoint);
zeroPoint = b.create<arith::TruncIOp>(loc, arg.getType(), zeroPoint);
} else {
zeroPoint =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(arg.getType()));
}
Value cmp;
if (intType) {
auto pred =
isUnsigned ? arith::CmpIPredicate::ugt : arith::CmpIPredicate::sgt;
cmp = b.create<arith::CmpIOp>(loc, pred, arg, zeroPoint);
} else {
cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, arg,
zeroPoint);
}
return b.create<arith::SelectOp>(loc, cmp, arg, zeroPoint);
}
if (auto round = dyn_cast<AtenRoundOp>(op)) {
if (!round.getType()
81 changes: 77 additions & 4 deletions lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp
@@ -20,6 +20,13 @@ using namespace mlir::torch::Torch;

namespace {

template <typename SrcOp> struct QuantInfo {
static constexpr unsigned operandsToQuantize[2] = {0, 1};
};

template <> struct QuantInfo<AtenReluOp> {
static constexpr unsigned operandsToQuantize[1] = {0};
};
template <typename SrcOp>
class QuantizeOperands : public OpRewritePattern<SrcOp> {
public:
@@ -42,8 +49,9 @@ class QuantizeOperands : public OpRewritePattern<SrcOp> {
return operand;
};

operands[0] = f(operands[0]);
operands[1] = f(operands[1]);
for (unsigned i : QuantInfo<SrcOp>::operandsToQuantize) {
operands[i] = f(operands[i]);
}

if (!dequanted) {
return rewriter.notifyMatchFailure(op, "no dequantizations found");
@@ -259,6 +267,70 @@ class QuantizeAccumulator : public OpRewritePattern<SrcOp> {
}
};

// Use for ops which do not manipulate scale/zero point of an input.
template <typename SrcOp>
class QuantizeResultLikeOperand : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;
LogicalResult matchAndRewrite(SrcOp op,
PatternRewriter &rewriter) const override {
llvm::SmallVector<Value> operands(op->getOperands());
Value input = operands[0];

auto inputType = dyn_cast_or_null<ValueTensorType>(input.getType());
if (!inputType || !inputType.hasDtype())
return failure();
auto qDtype = inputType.getDtype();

auto resultTy = dyn_cast_or_null<ValueTensorType>(op.getType());
if (!resultTy || !resultTy.hasDtype())
return failure();

Type resultETy = resultTy.getDtype();
if (!isa<mlir::FloatType>(resultETy))
return failure();

Value inputScale, inputZeroPoint;
Type definingOpInputType;
if (auto defining = input.template getDefiningOp<
Aten_MakePerTensorQuantizedTensorOp>()) {
inputScale = defining.getScale();
inputZeroPoint = defining.getZeroPoint();
definingOpInputType = defining.getSelf().getType();
}

auto inputIntReprType =
dyn_cast_or_null<ValueTensorType>(definingOpInputType);
if (!inputScale || !inputZeroPoint || !inputIntReprType ||
!inputIntReprType.hasDtype())
return failure();
auto intReprDtype = inputIntReprType.getDtype();

// set SrcOp type to use quantized dtype from input
auto newResultTy =
rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qDtype);
auto newResult = rewriter.create<SrcOp>(op.getLoc(), newResultTy, operands);

// int repr to get non quantized int type result
auto intReprTy = rewriter.getType<ValueTensorType>(
resultTy.getOptionalSizes(), intReprDtype);
auto intRepr =
rewriter.create<AtenIntReprOp>(op.getLoc(), intReprTy, newResult);

// requantize so the scale and zero-point info can be attached
auto quantTy =
rewriter.getType<ValueTensorType>(resultTy.getOptionalSizes(), qDtype);
auto quant = rewriter.create<Aten_MakePerTensorQuantizedTensorOp>(
op.getLoc(), quantTy, intRepr, inputScale, inputZeroPoint);

// dequant back to original dtype
auto dequant =
rewriter.create<AtenDequantizeTensorOp>(op.getLoc(), resultTy, quant);
rewriter.replaceOp(op, dequant);
return success();
}
};

template <typename SrcOp> class RemoveUnused : public OpRewritePattern<SrcOp> {
public:
using OpRewritePattern<SrcOp>::OpRewritePattern;
Expand All @@ -285,11 +357,12 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase<FuseQuantizedOpsPass> {
RemoveUnused<AtenQuantizePerTensorOp>,
RemoveUnused<Aten_MakePerTensorQuantizedTensorOp>,
RemoveUnused<AtenTransposeIntOp>, QuantizeOperands<AtenConvolutionOp>,
QuantizeOperands<AtenMatmulOp>,
QuantizeOperands<AtenMatmulOp>, QuantizeOperands<AtenReluOp>,
QuantizeTransposedOperands<AtenMatmulOp>,
QuantizeAccumulator<AtenMatmulOp>, QuantizeOperands<AtenMmOp>,
QuantizeTransposedOperands<AtenMmOp>, QuantizeAccumulator<AtenMmOp>,
QuantizeBias<AtenConvolutionOp>>(context);
QuantizeResultLikeOperand<AtenReluOp>, QuantizeBias<AtenConvolutionOp>>(
context);

GreedyRewriteConfig config;
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
9 changes: 9 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -331,6 +331,9 @@
"AtenMatmulQint8VV_basic",
"AtenMatmulQint8VM_basic",
"AtenMatmulQint8_basic",
"QuantizedReluInt32_basic",
"QuantizedReluInt8_basic",
"QuantizedReluUint8_basic",
"Conv2dQInt8Module_basic",

# Dynamo not supporting conv_tbc
@@ -413,6 +416,9 @@
'AtenMmQMixedSigni8_basic',
'AtenMmQint8_basic',
'AtenMmQuint8_basic',
"QuantizedReluInt32_basic",
"QuantizedReluInt8_basic",
"QuantizedReluUint8_basic",
'AtenSubFloatModule_basic',
'BincountMinlengthModule_basic',
'BincountModule_basic',
@@ -2466,6 +2472,9 @@
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"QuantizedReluInt8_basic",
"QuantizedReluInt32_basic",
"QuantizedReluUint8_basic",
"RandIntDtypeModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
63 changes: 63 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -705,6 +705,69 @@ def ElementwiseReluModule_basic(module, tu: TestUtils):

# ==============================================================================

class QuantizedReluInt8(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.int8, True),
])
def forward(self, x):
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25)
qx = torch.dequantize(qx)
return torch.relu(qx)

@register_test_case(module_factory=lambda: QuantizedReluInt8())
def QuantizedReluInt8_basic(module, tu: TestUtils):
module.forward(tu.randint(7, 4, low=-128, high=127).to(torch.int8))

# ==============================================================================

class QuantizedReluUint8(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.uint8, True),
])
def forward(self, x):
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190)
qx = torch.dequantize(qx)
return torch.relu(qx)

@register_test_case(module_factory=lambda: QuantizedReluUint8())
def QuantizedReluUint8_basic(module, tu: TestUtils):
module.forward(tu.randint(7, 4, low=0, high=255).to(torch.uint8))

# ==============================================================================

class QuantizedReluInt32(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, x):
qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, 190)
qx = torch.dequantize(qx)
return torch.relu(qx)

@register_test_case(module_factory=lambda: QuantizedReluInt32())
def QuantizedReluInt32_basic(module, tu: TestUtils):
module.forward(tu.randint(7, 4, low=(-2**31), high=(2**31 - 1)).to(torch.int32))

# ==============================================================================


class ElementwiseRelu6Module(torch.nn.Module):
