Fix some bugs for ReduceMin/Max (#3038)
Signed-off-by: Tung D. Le <[email protected]>
tungld authored Jan 10, 2025
1 parent 8c54c0c commit 0183ad9
Showing 14 changed files with 250 additions and 155 deletions.
84 changes: 38 additions & 46 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
@@ -953,49 +953,29 @@ bool isSuitableForZDNN<ONNXGemmOp>(
// Common function for ReduceMax and ReduceMin
template <typename OP_TYPE>
static bool checkReduceParam(OP_TYPE op) {
IndexExprBuilderForAnalysis createIE(op.getLoc());
OpBuilder b(op);
Location loc = op.getLoc();
IndexExprBuilderForAnalysis createIE(loc);
IndexExprScope ieScope(&b, loc);

Value data = op.getData();
Value axesVal = op.getAxes();
int64_t keepdims = op.getKeepdims();
int64_t noop_with_empty_axes = op.getNoopWithEmptyAxes();

// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPALevel::M15))
return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15);

// Check data type.
Value data = op.getData();
int64_t rank = getRank(data.getType());
if (!isValidElementTypeAndRank(op.getOperation(), data))
return false;

// Check axes value
Value axesVal = op.getAxes();
if (!isDenseONNXConstant(axesVal))
return false;

ONNXConstantOp axesConstant =
mlir::cast<ONNXConstantOp>(axesVal.getDefiningOp());
int64_t axesInt = getScalarValue<int64_t>(axesConstant);

int64_t keepdims = op.getKeepdims();
int64_t noop_with_empty_axes = op.getNoopWithEmptyAxes();
int64_t rank = createIE.getShapedTypeRank(data);

// Check if axes (int64) is exactly a size of one
if (floor(log10(axesInt)) + 1 == 1) {
int64_t axis = axesInt;
// Accepted range is [-r, r-1] where r = rank(data)
if (axis < -rank || axis > rank - 1) {
std::string message =
"The `axis` is out of the accepted range which is [-r, r-1]";
return onnxToZHighUnsupportedReport(op, message);
}
if ((axis != -1) && (axis != rank - 1)) {
std::string message = "The `axis` must be the innermost dimension. ";
return onnxToZHighUnsupportedReport(op, message);
}
} else {
std::string message = "Axes can only be a scalar size of one. ";
return onnxToZHighUnsupportedReport(op, message);
}

// REMINDER: Should we check the input tensor rank.
// NNPA does not support reduction over all axes.
if (isNoneValue(axesVal))
return onnxToZHighUnsupportedReport(
op.getOperation(), "Does not support reduction over all axes.");

// Check keepdims and noop_with_empty_axes, we only support the default
// value. Attributes: keepdims (default is 1) and noop_with_empty_axes
@@ -1007,33 +987,45 @@ static bool checkReduceParam(OP_TYPE op) {
std::to_string(keepdims) + ") must be 1.";
return onnxToZHighUnsupportedReport(op, message);
}

// Check axes value
DimsExpr axesIE;
createIE.getIntFromArrayAsDims(axesVal, axesIE);
if (axesIE.size() != 1)
return onnxToZHighUnsupportedReport(
op.getOperation(), "Does not support multiple reduction axes.");
if (!axesIE[0].isLiteral())
return onnxToZHighUnsupportedReport(
op.getOperation(), "Reduction axis is unknown at compile time.");
int64_t axis = axesIE[0].getLiteral();
// Accepted range is [-r, r-1] where r = rank(data)
if (axis < -rank || axis > rank - 1) {
std::string message =
"Reduction axis is out of the accepted range which is [-r, r-1]";
return onnxToZHighUnsupportedReport(op, message);
}
if ((axis != -1) && (axis != rank - 1)) {
std::string message = "Reduction axis must be the innermost dimension. ";
return onnxToZHighUnsupportedReport(op, message);
}

return true;
}

/// Check legality for ONNXReduceMax.
template <>
bool isSuitableForZDNN<ONNXReduceMaxOp>(
ONNXReduceMaxOp op, const DimAnalysis *dimAnalysis) {

// Check parameter restrictions for ReduceMax
bool isReduceMax = checkReduceParam<ONNXReduceMaxOp>(op);
if (!isReduceMax)
return false;

return true;
return checkReduceParam<ONNXReduceMaxOp>(op);
}

/// Check legality for ONNXReduceMin.
template <>
bool isSuitableForZDNN<ONNXReduceMinOp>(
ONNXReduceMinOp op, const DimAnalysis *dimAnalysis) {

// Check parameter restrictions for ReduceMin
bool isReduceMin = checkReduceParam<ONNXReduceMinOp>(op);
if (!isReduceMin)
return false;

return true;
return checkReduceParam<ONNXReduceMinOp>(op);
}

/// Check legality for ONNXReduceMeanV13.
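For reference, the axis constraint enforced above boils down to a small check: a single literal axis that lies in [-r, r-1] and, after normalizing negatives, names the innermost dimension. A minimal standalone sketch of that rule (the helper name is hypothetical, not part of this patch):

#include <cstdint>
#include <string>

// Hypothetical helper mirroring the constraint above: one literal axis,
// within [-rank, rank-1], and equal to the innermost dimension.
static bool isSupportedReductionAxis(int64_t axis, int64_t rank, std::string &reason) {
  if (axis < -rank || axis > rank - 1) {
    reason = "axis out of the accepted range [-r, r-1]";
    return false;
  }
  int64_t normalized = (axis < 0) ? axis + rank : axis; // map negatives to [0, r-1]
  if (normalized != rank - 1) {
    reason = "axis must be the innermost dimension";
    return false;
  }
  return true;
}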
4 changes: 2 additions & 2 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td
@@ -471,7 +471,7 @@ def replaceONNXReduceMeanV13Pattern : Pat<
def replaceONNXReduceMaxPattern : Pat<
(ONNXReduceMaxOp:$res $data, $axes, $keepdims, $noop_with_empty_axes),
(ZHighUnstickOp (ZHighReduceMaxOp (ZHighStickOp:$s_x $data, (NoneLayoutAttr),
(GetDefaultSaturation)), (GetStringAttr<"REDUCE_OP_MAXIMUM">))),
(GetDefaultSaturation)))),
[(IsCompatibleWithNNPALevelArch15)]
>;

@@ -484,7 +484,7 @@ def replaceONNXReduceMaxPattern : Pat<
def replaceONNXReduceMinPattern : Pat<
(ONNXReduceMinOp:$res $data, $axes, $keepdims, $noop_with_empty_axes),
(ZHighUnstickOp (ZHighReduceMinOp (ZHighStickOp:$s_x $data, (NoneLayoutAttr),
(GetDefaultSaturation)), (GetStringAttr<"REDUCE_OP_MINIMUM">))),
(GetDefaultSaturation)))),
[(IsCompatibleWithNNPALevelArch15)]
>;
//===----------------------------------------------------------------------===//
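Numerically, the rewritten patterns target a ReduceMax/ReduceMin over the innermost axis with keepdims in effect. A standalone sketch of that computation on a flattened [outer x inner] view of the data (plain C++, not tied to the generated code; assumes inner >= 1):

#include <algorithm>
#include <cstdint>
#include <vector>

// ReduceMax over the innermost axis with keepdims=1: one maximum per
// "outer" row; the result logically has shape [outer, 1].
std::vector<float> reduceMaxInnermost(const std::vector<float> &data,
                                      int64_t outer, int64_t inner) {
  std::vector<float> out(static_cast<size_t>(outer));
  for (int64_t i = 0; i < outer; ++i)
    out[static_cast<size_t>(i)] = *std::max_element(
        data.begin() + i * inner, data.begin() + (i + 1) * inner);
  return out;
}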
21 changes: 11 additions & 10 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
@@ -1123,33 +1123,34 @@ struct ZHighToZLowReduceOpLowering : public ConversionPattern {
Value data = operands[0];

// Helper builders.
MultiDialectBuilder<IndexExprBuilderForKrnl, KrnlBuilder, LLVMBuilder,
MemRefBuilder>
MultiDialectBuilder<IndexExprBuilderForKrnl, KrnlBuilder, MemRefBuilder>
create(rewriter, loc);

// Convert ZTensor type to MemRefType.
ZMemRefType zMemRefType =
convertZTensorToMemRefType(*op->result_type_begin());

// Shape helper.
ZHighReduceMaxOpShapeHelper shapeHelper(op, operands, &create.krnlIE);
ZHighReductionOpShapeHelper<OP_TYPE> shapeHelper(
op, operands, &create.krnlIE);
shapeHelper.computeShapeAndAssertOnFailure();
SmallVector<IndexExpr, 4> &dims = shapeHelper.getOutputDims();

// Allocate a buffer for the result MemRef.
Value alloc = insertAllocForZMemRef(
zMemRefType, shapeHelper.getOutputDims(), op, rewriter);

// Get the original shape before it is vanished by lower passes.
Value shape = insertShapeMemRefI64(rewriter, loc, dims);
DimsExpr dataDims;
create.krnlIE.getShapeAsDims(data, dataDims);
Value shape = insertShapeMemRefI64(rewriter, loc, dataDims);

// If set to NULL, the operation will determine, allocate and free storage
// automatically.
Value workArea = create.llvm.null(krnl::getI8PointerType(context));
// Emit 'alloc' for work_area that is of 4K-aligned 8K bytes.
Value workArea = create.mem.alignedAlloc(
MemRefType::get({8 * 1024}, rewriter.getIntegerType(8)), gAlignment);

// Emit a ZLow operation.
rewriter.create<typename ZLowReduceOpFor<OP_TYPE>::Op>(loc, data, workArea,
shape, alloc, zMemRefType.layout, reduceOp.getOpTypeAttr());
rewriter.create<typename ZLowReduceOpFor<OP_TYPE>::Op>(
loc, data, workArea, shape, alloc, zMemRefType.layout);
rewriter.replaceOp(op, alloc);
return success();
}
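Two behavioral changes above are easy to miss: the i64 shape MemRef now records the input dims (captured before later passes erase the original shape), and the zDNN work area is pre-allocated as a 4K-aligned 8 KiB buffer instead of being passed as NULL. A concrete instance with an assumed 3-D shape:

#include <cstdint>

// Assumed example: reducing a {2, 3, 5} zTensor over its innermost axis.
static const int64_t kInputShape[3]  = {2, 3, 5}; // contents of the i64 shape MemRef
static const int64_t kOutputShape[3] = {2, 3, 1}; // dims used to allocate the result
static const int64_t kWorkAreaBytes  = 8 * 1024;  // 8 KiB i8 scratch buffer, 4K-aligned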
33 changes: 13 additions & 20 deletions src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp
@@ -1152,18 +1152,18 @@ class ZLowLeakyReluLowering : public ConvertToLLVMPattern {
};

template <typename REDUCE_OP>
API getReduceAPI() {
return API::NULL_API;
zdnn_reduce_ops getZDNNReduceOpType() {
return REDUCE_OP_MAXIMUM;
}

template <>
API getReduceAPI<ZLowReduceMaxOp>() {
return API::ZDNN_REDUCEMAX;
zdnn_reduce_ops getZDNNReduceOpType<ZLowReduceMaxOp>() {
return REDUCE_OP_MAXIMUM;
}

template <>
API getReduceAPI<ZLowReduceMinOp>() {
return API::ZDNN_REDUCEMIN;
zdnn_reduce_ops getZDNNReduceOpType<ZLowReduceMinOp>() {
return REDUCE_OP_MINIMUM;
}

template <typename REDUCE_OP>
@@ -1210,8 +1210,8 @@ class ZLowReduceLowering : public ConvertToLLVMPattern {
Value one = create.llvm.constant(llvmI64Ty, static_cast<int64_t>(1));

// Calculation for the output dimension
int64_t axis = dims.size();
std::vector<Value> outputDims;
int64_t axis = dims.size() - 1;
SmallVector<Value, 4> outputDims;
for (int64_t i = 0; i < axis; ++i) {
outputDims.emplace_back(dims[i]);
}
Expand All @@ -1236,20 +1236,13 @@ class ZLowReduceLowering : public ConvertToLLVMPattern {
zTensorHelper.getAlignedI8Ptr(operandAdaptor.getWorkArea());

// op_type
nnpa_reduce_operations opType;
StringRef opTypeStr = reduceOp.getOpType();
if (opTypeStr.equals_insensitive("REDUCE_OP_MINIMUM"))
opType = NNPA_REDUCE_OP_MINIMUM;
else if (opTypeStr.equals_insensitive("REDUCE_OP_MAXIMUM"))
opType = NNPA_REDUCE_OP_MAXIMUM;
else
llvm_unreachable("Unsupported operation type");
Value optype = create.llvm.constant(
rewriter.getI64Type(), static_cast<int64_t>(opType));
zdnn_reduce_ops zdnnOpType = getZDNNReduceOpType<REDUCE_OP>();
Value opType = create.llvm.constant(
rewriter.getI64Type(), static_cast<int64_t>(zdnnOpType));

// Call the zDNN ReduceMax/ReduceMin API.
callApi(rewriter, loc, module, apiRegistry, getReduceAPI<REDUCE_OP>(),
{toOpaquePtr(rewriter, loc, module, inputZTensor.val), workArea, optype,
callApi(rewriter, loc, module, apiRegistry, API::ZDNN_REDUCE,
{toOpaquePtr(rewriter, loc, module, inputZTensor.val), workArea, opType,
toOpaquePtr(rewriter, loc, module, outputZTensor.val)});

rewriter.eraseOp(op);
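The op-type string attribute and its runtime string comparison are gone; the zDNN op code is now resolved at compile time via template specialization and passed as an i64 constant to a single ZDNN_REDUCE API. A self-contained sketch of that idiom (tag types and enum values are placeholders, not the real zDNN definitions):

// Placeholder enum standing in for zdnn_reduce_ops from the zDNN headers.
enum ReduceOpCode { kReduceMaximum = 0, kReduceMinimum = 1 };

struct ReduceMaxTag {};
struct ReduceMinTag {};

// Primary template is only declared; every supported op supplies a specialization.
template <typename REDUCE_OP>
ReduceOpCode opCodeFor();

template <> ReduceOpCode opCodeFor<ReduceMaxTag>() { return kReduceMaximum; }
template <> ReduceOpCode opCodeFor<ReduceMinTag>() { return kReduceMinimum; }

// Usage: the code is known at compile time, so no string parsing is needed.
static const ReduceOpCode kMaxCode = opCodeFor<ReduceMaxTag>();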
@@ -69,8 +69,7 @@ ApiRegistry RegisterAllApis(MLIRContext *context) {
ApiSpec(API::ZDNN_LOG, "zdnn_log_ext", int32Ty, {opaquePtrTy, opaquePtrTy}, false),
ApiSpec(API::ZDNN_EXP, "zdnn_exp_ext", int32Ty, {opaquePtrTy, opaquePtrTy}, false),
ApiSpec(API::ZDNN_INVSQRT, "zdnn_invsqrt_ext", int32Ty, {opaquePtrTy, float32Ty, opaquePtrTy}, false),
ApiSpec(API::ZDNN_REDUCEMAX, "zdnn_reduce_ext", int32Ty, {opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy}, false),
ApiSpec(API::ZDNN_REDUCEMIN, "zdnn_reduce_ext", int32Ty, {opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy}, false),
ApiSpec(API::ZDNN_REDUCE, "zdnn_reduce_ext", int32Ty, {opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy}, false),
// Activation operations
ApiSpec(API::ZDNN_LEAKY_RELU, "zdnn_leaky_relu_ext", int32Ty, {opaquePtrTy, opaquePtrTy, float32Ty, opaquePtrTy}, false),
ApiSpec(API::ZDNN_RELU, "zdnn_relu_ext", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy}, false),
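With the registration above, both zlow.reducemax and zlow.reducemin funnel into one runtime entry point. A hedged sketch of the prototype implied by the ApiSpec (int32 return; opaque pointer, opaque pointer, i64, opaque pointer); parameter names are assumptions, and the authoritative prototype lives in the zDNN headers:

#include <cstdint>

extern "C" int32_t zdnn_reduce_ext(void *input_ztensor, void *work_area,
                                   int64_t op_type, void *output_ztensor);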
@@ -48,8 +48,9 @@ enum class API {
ZDNN_LOG,
ZDNN_EXP,
ZDNN_INVSQRT,
ZDNN_REDUCEMAX,
ZDNN_REDUCEMIN,
// Reduction operations
ZDNN_REDUCE,
ZDNN_MEANREDUCE2D,
// Activation operations
ZDNN_RELU,
ZDNN_GELU,
@@ -68,7 +69,6 @@
ZDNN_CONV2D,
ZDNN_AVGPOOL2D,
ZDNN_MAXPOOL2D,
ZDNN_MEANREDUCE2D,
ZDNN_BATCHNORM,
ZDNN_LEAKY_RELU,
// Scalar operations.
14 changes: 6 additions & 8 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
@@ -654,14 +654,13 @@ def ZHighReduceMaxOp:ZHigh_Op<"ReduceMax", [Pure, SameOperandsAndResultLayout,
ZHigh operation to perform a ReduceMax.
op_type: REDUCE_OP_MAXIMUM or REDUCE_OP_MINIMUM.
}];
let arguments = (ins AnyTypeOf<[AnyZTensor]>:$data,
DefaultValuedStrAttr<StrAttr, "REDUCE_OP_MAXIMUM">:$op_type);
let arguments = (ins AnyTypeOf<[AnyZTensor]>:$data);
let results = (outs AnyTypeOf<[AnyZTensor]>:$output);
let builders = [
OpBuilder<(ins "::mlir::Value":$data, "::mlir::StringAttr":$op_type), [{
OpBuilder<(ins "::mlir::Value":$data), [{
Type elementType = mlir::cast<ShapedType>(data.getType()).getElementType();
UnrankedTensorType resType = UnrankedTensorType::get(elementType);
build($_builder, $_state, resType, data, op_type);
build($_builder, $_state, resType, data);
}]>
];
let extraClassDefinition = [{
@@ -682,14 +681,13 @@ def ZHighReduceMinOp:ZHigh_Op<"ReduceMin", [Pure, SameOperandsAndResultLayout,
ZHigh operation to perform a ReduceMin.
op_type: REDUCE_OP_MAXIMUM or REDUCE_OP_MINIMUM.
}];
let arguments = (ins AnyTypeOf<[AnyZTensor]>:$data,
DefaultValuedStrAttr<StrAttr, "REDUCE_OP_MINIMUM">:$op_type);
let arguments = (ins AnyTypeOf<[AnyZTensor]>:$data);
let results = (outs AnyTypeOf<[AnyZTensor]>:$output);
let builders = [
OpBuilder<(ins "::mlir::Value":$data, "::mlir::StringAttr":$op_type), [{
OpBuilder<(ins "::mlir::Value":$data), [{
Type elementType = mlir::cast<ShapedType>(data.getType()).getElementType();
UnrankedTensorType resType = UnrankedTensorType::get(elementType);
build($_builder, $_state, resType, data, op_type);
build($_builder, $_state, resType, data);
}]>
];
let extraClassDefinition = [{
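After dropping the op_type attribute, the declared builders take only the input value and infer an unranked result type with the same element type. A minimal sketch of invoking one of them from C++ (the include path and namespace are assumptions based on this repository's layout):

#include "mlir/IR/Builders.h"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp" // assumed header path

// Sketch: create a zhigh.ReduceMax over an already-stickified value.
mlir::Value buildReduceMax(mlir::OpBuilder &builder, mlir::Location loc,
                           mlir::Value stickifiedData) {
  return builder.create<onnx_mlir::zhigh::ZHighReduceMaxOp>(loc, stickifiedData)
      .getOutput();
}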
@@ -29,25 +29,22 @@ LogicalResult ZHighReductionOpShapeHelper<OP_TYPE>::computeShape() {
// Get operand.
Value data = operandAdaptor.getData();

// Get Rank
int64_t rank = createIE->getShapedTypeRank(data);

// Output dims of result.
DimsExpr outputDims;

// Get operands and bounds.
SmallVector<IndexExpr, 4> inputDims;
createIE->getShapeAsDims(data, inputDims);

int64_t axis = rank - 1;
// NNPA only supports reduction over the innermost dimension.
// So set the innermost dimension of the output to one.
int64_t axis = inputDims.size() - 1;
LiteralIndexExpr one(1);
// Copy the input until the second to last dimension
for (int64_t i = 0; i < axis; ++i) {
outputDims.emplace_back(inputDims[i]);
}
// The innermost dimension or last dimension needs to be reduced to one
outputDims.emplace_back(
one); // NNPA is always true for keepdims so we will reduce the dimension
outputDims.emplace_back(one);

// Save the final result.
setOutputDims(outputDims);
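In other words, the shape rule above copies every dimension except the last and pins the innermost one to 1 (keepdims is effectively always on for NNPA). A standalone sketch of that rule on plain integer shapes, assuming rank >= 1:

#include <cstdint>
#include <vector>

// Copy all but the innermost dimension, then append 1 for the reduced axis.
std::vector<int64_t> reducedOutputShape(const std::vector<int64_t> &inputDims) {
  std::vector<int64_t> outputDims(inputDims.begin(), inputDims.end() - 1);
  outputDims.push_back(1);
  return outputDims;
}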
6 changes: 2 additions & 4 deletions src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
@@ -259,8 +259,7 @@ def ZLowReduceMaxOp:ZLow_Op<"reducemax", [MemRefsNormalizable]> {
MemRefOf<[I8]>:$work_area,
MemRefOf<[I64]>:$shape,
ZMemRef:$Out,
StrAttr:$layout,
StrAttr:$op_type);
StrAttr:$layout);
}

def ZLowReduceMinOp:ZLow_Op<"reducemin", [MemRefsNormalizable]> {
@@ -272,8 +271,7 @@ def ZLowReduceMinOp:ZLow_Op<"reducemin", [MemRefsNormalizable]> {
MemRefOf<[I8]>:$work_area,
MemRefOf<[I64]>:$shape,
ZMemRef:$Out,
StrAttr:$layout,
StrAttr:$op_type);
StrAttr:$layout);
}

def ZLowMatMulOp:ZLow_Op<"matmul", [MemRefsNormalizable,
25 changes: 4 additions & 21 deletions test/accelerators/NNPA/backend/CMakeLists.txt
@@ -438,31 +438,14 @@ set(NNPA_TEST_LIST_ARCH_15
# ==OP== ReduceMax
# ==LEVEL== arch15
# ==MIN== 1
# ==LIM== - We do no support `do_not_keepdims` backend tests.
# test_reduce_max_bool_inputs_cpu,zdnn_reduce_ext
test_reduce_max_default_axes_keepdim_example_cpu,zdnn_reduce_ext
test_reduce_max_default_axes_keepdims_random_cpu,zdnn_reduce_ext
# test_reduce_max_do_not_keepdims_example_cpu,zdnn_reduce_ext
# test_reduce_max_do_not_keepdims_random_cpu,zdnn_reduce_ext
test_reduce_max_keepdims_example_cpu,zdnn_reduce_ext
test_reduce_max_keepdims_random_cpu,zdnn_reduce_ext
test_reduce_max_negative_axes_keepdims_example_cpu,zdnn_reduce_ext
test_reduce_max_negative_axes_keepdims_random_cpu,zdnn_reduce_ext
# ==LIM== - We do not support `do_not_keepdims` backend tests. Only support reduction over the innermost dimension.
# Currently, there is no backend test in ONNX that does reduction on the innermost dimension.

# ==OP== ReduceMin
# ==LEVEL== arch15
# ==MIN== 1
# ==LIM== - We do no support `do_not_keepdims` backend tests.
# test_reduce_min_bool_inputs_cpu,zdnn_reduce_ext
test_reduce_min_default_axes_keepdims_example_cpu,zdnn_reduce_ext
test_reduce_min_default_axes_keepdims_random_cpu,zdnn_reduce_ext
# test_reduce_min_do_not_keepdims_example_cpu,zdnn_reduce_ext
# test_reduce_min_do_not_keepdims_random_cpu,zdnn_reduce_ext
test_reduce_min_empty_set_cpu,zdnn_reduce_ext
test_reduce_min_keepdims_example_cpu,zdnn_reduce_ext
test_reduce_min_keepdims_random_cpu,zdnn_reduce_ext
test_reduce_min_negative_axes_keepdims_example_cpu,zdnn_reduce_ext
test_reduce_min_negative_axes_keepdims_random_cpu,zdnn_reduce_ext
# ==LIM== - We do not support `do_not_keepdims` backend tests. Only support reduction over the innermost dimension.
# Currently, there is no backend test in ONNX that does reduction on the innermost dimension.

# ==OP== Sqrt
# ==LEVEL== arch15