Skip to content

Commit

Permalink
Use DisposableElementsAttr for ZHigh constant propagation (#3013)
Browse files Browse the repository at this point in the history
* Revert "[NNPA] Memory reduction of stickified constant by stickifying at file writing  (#2917)"

This reverts commit 33b466e.

Signed-off-by: Tung D. Le <[email protected]>

* Use DisposableElementsAttr for ZHigh Constant Propagation

Signed-off-by: Tung D. Le <[email protected]>


---------

Signed-off-by: Tung D. Le <[email protected]>
  • Loading branch information
tungld authored Dec 4, 2024
1 parent 40f5017 commit f3fec68
Show file tree
Hide file tree
Showing 50 changed files with 781 additions and 849 deletions.
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_onnx_mlir_library(OMNNPACompilerOptions

add_onnx_mlir_library(OMNNPACompilerUtils
NNPACompilerUtils.cpp
ZHighDisposableGarbageCollector.cpp

EXCLUDE_FROM_OM_LIBS

Expand Down
14 changes: 8 additions & 6 deletions src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
#include "src/Accelerators/NNPA/Compiler/NNPACompilerUtils.hpp"
#include "src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
Expand Down Expand Up @@ -120,10 +121,6 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
}

// Replace every DisposableElementsAttr with DenseElementsAttr.
// ZHighConstPropagation currently assumes that DenseElementsAttr is used.
pm.addPass(createScrubDisposablePass());

// Experimental feature: Decompose stick/unstick into two phases: layout
// transform and data conversion. Do some optimizations after decomposing.
// Then, recompose again layout and data conversion if they are not optimized.
Expand All @@ -146,15 +143,17 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
// Only support BE machines.
bool isBE = llvm::endianness::native == llvm::endianness::big;
if (isBE)
pm.addNestedPass<func::FuncOp>(
onnx_mlir::zhigh::createZHighConstPropagationPass());
pm.addPass(onnx_mlir::zhigh::createZHighConstPropagationPass());

// Remove common sub-expressions.
pm.addPass(mlir::createCSEPass());

// Clean dead code.
pm.addPass(mlir::createSymbolDCEPass());

// Replace every DisposableElementsAttr with DenseElementsAttr.
pm.addPass(onnx_mlir::zhigh::createZHighScrubDisposablePass());

// Insert an instrumentation after lowering onnx to zhigh to get profiling
// for onnx and zhigh ops.
// Keep this pass at the end of this function.
Expand Down Expand Up @@ -195,6 +194,9 @@ void addPassesNNPA(mlir::OwningOpRef<mlir::ModuleOp> &module,

// LLVM_DEBUG(llvm::dbgs() << "Adding NNPA passes" << std::endl;);
if (emissionTarget >= EmitONNXIR) {
pm.addInstrumentation(
std::make_unique<onnx_mlir::zhigh::ZHighDisposableGarbageCollector>(
pm.getContext()));
addONNXToMLIRPasses(pm, /*target CPU*/ maccel.empty(),
/*donotScrubDisposableElementsAttr*/ true);
pm.addPass(onnx_mlir::createDevicePlacementPass(nnpaLoadDevicePlacementFile,
Expand Down
43 changes: 43 additions & 0 deletions src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- ZHighDisposableGarbageCollector.cpp -----------------===//
//
// Garbage collects DisposableElementsAttr attributes.
//
//===----------------------------------------------------------------------===//

#include "src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp"
#include "src/Dialect/ONNX/ONNXDialect.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

namespace onnx_mlir {
namespace zhigh {

/// Construct the collector and bind it to the DisposablePool that
/// DisposablePool::get<ONNXDialect>() returns for `context`.
/// NOTE(review): the result of get<ONNXDialect>() is dereferenced without a
/// check; presumably the ONNX dialect and its pool are already set up in
/// `context` before this instrumentation is created -- confirm at call sites.
ZHighDisposableGarbageCollector::ZHighDisposableGarbageCollector(
MLIRContext *context)
: disposablePool(*DisposablePool::get<ONNXDialect>(context)) {}

// Nothing to release here: the collector only holds a reference to the pool.
ZHighDisposableGarbageCollector::~ZHighDisposableGarbageCollector() = default;

/// Pass-instrumentation hook invoked after a pass has run on `op`.
/// When the pool is active and `op` is a ModuleOp, asks the pool to garbage
/// collect whatever is unreachable from the "value" attribute of
/// ONNXConstantOp, ONNXConstantOfShapeOp and ZHighStickifiedConstantOp in
/// that module. In every other case this is a no-op. `pass` is not used.
void ZHighDisposableGarbageCollector::runAfterPass(Pass *pass, Operation *op) {
  if (disposablePool.isActive()) {
    // Only whole-module invocations trigger a collection.
    if (ModuleOp enclosingModule = mlir::dyn_cast<ModuleOp>(op)) {
      disposablePool.garbageCollectUnreachable(enclosingModule,
          {{ONNXConstantOp::getOperationName(), "value"},
              {ONNXConstantOfShapeOp::getOperationName(), "value"},
              {ZHighStickifiedConstantOp::getOperationName(), "value"}});
    }
  }
}

} // namespace zhigh
} // namespace onnx_mlir
37 changes: 37 additions & 0 deletions src/Accelerators/NNPA/Compiler/ZHighDisposableGarbageCollector.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===---------------- ZHighDisposableGarbageCollector.hpp -----------------===//
//
// Garbage collects DisposableElementsAttr attributes.
//
//===----------------------------------------------------------------------===//

#ifndef ONNX_MLIR_ZHIGH_GARBAGE_COLLECTOR_H
#define ONNX_MLIR_ZHIGH_GARBAGE_COLLECTOR_H

#include "mlir/Pass/PassInstrumentation.h"

namespace mlir {
class MLIRContext;
}

namespace onnx_mlir {
class DisposablePool;

namespace zhigh {

/// Pass instrumentation that garbage collects DisposableElementsAttr
/// attributes. It overrides the runAfterPass hook of mlir::PassInstrumentation
/// and keeps a reference to a DisposablePool.
struct ZHighDisposableGarbageCollector : public mlir::PassInstrumentation {
/// Binds the collector to a DisposablePool looked up through `context`.
ZHighDisposableGarbageCollector(mlir::MLIRContext *context);
~ZHighDisposableGarbageCollector() override;

/// Hook called after `pass` has run on `op`.
void runAfterPass(mlir::Pass *pass, mlir::Operation *op) override;

private:
// Held by reference, so this object does not own the pool.
// NOTE(review): the pool presumably outlives the collector -- confirm.
DisposablePool &disposablePool;
};

} // namespace zhigh
} // namespace onnx_mlir
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ bool isF32ScalarConstantTensor(Value v) {
FloatAttr getScalarF32AttrFromConstant(Value v) {
if (!isF32ScalarConstantTensor(v))
return nullptr;
DenseElementsAttr constElements = ElementsAttrBuilder::toDenseElementsAttr(
getElementAttributeFromONNXValue(v));
ElementsAttr constElements = getElementAttributeFromONNXValue(v);
return constElements.getSplatValue<FloatAttr>();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

//===---------- ONNXToZHighCommon.hpp - Common functions in ONNXToZHigh
//---------===//
//===---- ONNXToZHighCommon.hpp - Common functions in ONNXToZHigh ---------===//
//
// Copyright 2019-2024 The IBM Research Authors.
//
Expand Down Expand Up @@ -117,4 +116,4 @@ mlir::Value getDynShape(
mlir::Location loc, mlir::PatternRewriter &rewriter, mlir::Value x);

} // namespace onnx_mlir
#endif
#endif
159 changes: 71 additions & 88 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,25 @@ static Value insertAllocForWorkAreaForRNNOps(IndexExprBuilderForKrnl &createIE,
return create.mem.alignedAlloc(resultType, dims, gAlignment);
}

/// Get a dense resource attribute that stores stickified data in which every
/// byte equals the given i8 value.
/// Attribute type: tensor<sizeInBytes x i8>
///
/// \param rewriter           provides the i8 element type of the attribute.
/// \param stickifiedConstant op whose dialect namespace is used as the blob
///                           name "hint".
/// \param val                byte value replicated over the whole buffer.
/// \param sizeInBytes        size of the buffer in bytes; must not be negative.
/// \return a resource attribute backed by a heap blob of `sizeInBytes` bytes.
DenseResourceElementsAttr getDenseResourceElementsAttrOfValue(
    PatternRewriter &rewriter, ZHighStickifiedConstantOp stickifiedConstant,
    int8_t val, int64_t sizeInBytes) {
  assert(sizeInBytes >= 0 && "stickified data size must not be negative");
  const size_t numBytes = static_cast<size_t>(sizeInBytes);
  // Allocate the blob once and fill it in place. Going through a temporary
  // malloc'ed buffer and then copying it into the blob would double the peak
  // memory for large constants and would rely on an assert-only null check.
  auto blob = HeapAsmResourceBlob::allocate(numBytes, alignof(char));
  memset(blob.getMutableData().data(), val, numBytes);
  return DenseUI8ResourceElementsAttr::get(
      RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
      stickifiedConstant.getOperation()
          ->getDialect()
          ->getNamespace(), // use the dialect as the blob "hint"
      std::move(blob));
}

/// This function emits a buffer of zero elements for the given dimensions and
/// layout. If the given dimensions are static, then a stickified constant is
/// returned.
Expand All @@ -190,48 +209,18 @@ Value insertAllocOrEmitZeroConstant(ArrayRef<IndexExpr> dims,
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));

// Create a ZHighStickifiedConstantOp.

// Keep previous implementation about generating stickified data at
// ZHighConstPropagationPass. To use this, comment in and set directive "
// NNPA_ZHIGH_STICKIFIEDCONST_GEN"
//
// #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
// // Set zero in value attribute as DenseResourceElementsAttribute.
// ZHighStickifiedConstantOp stickifiedConstant =
// rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
// /*stickified=*/rewriter.getBoolAttr(true),
// /*value=*/nullptr,
// /*alignment=*/rewriter.getI64IntegerAttr(4096));
//
// // Use an dense resource attribute to store stickified data.
// // Attribute type: tensor<sizeInBytes x i8>
// int64_t sizeInBytes =
// affine::getIntOrFloatMemRefSizeInBytes(resType).value();
// char *rawData = static_cast<char *>(malloc(sizeInBytes));
// assert(rawData && "failed to allocate memory for stickified data");
// memset(rawData, 0, sizeInBytes);
// DenseResourceElementsAttr valueAttr =
// DenseUI8ResourceElementsAttr::get(
// RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
// stickifiedConstant.getOperation()
// ->getDialect()
// ->getNamespace(), // use the dialect as the blob "hint"
// HeapAsmResourceBlob::allocateAndCopyWithAlign(
// llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
// stickifiedConstant.setValueAttr(valueAttr);
// free(rawData);
// #else

// Set zero in value attribute as SplatElementsAttr.
FloatAttr floatZero = rewriter.getFloatAttr(resType.getElementType(), 0.0);
ZHighStickifiedConstantOp stickifiedConstant = rewriter.create<
ZHighStickifiedConstantOp>(loc, resType,
/*stickified=*/rewriter.getBoolAttr(true),
/*value=*/SplatElementsAttr::get(cast<ShapedType>(resType), floatZero),
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN

ZHighStickifiedConstantOp stickifiedConstant =
rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
/*value=*/nullptr,
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// Use a dense resource attribute to store stickified data.
// Attribute type: tensor<sizeInBytes x i8>
int64_t sizeInBytes =
affine::getIntOrFloatMemRefSizeInBytes(resType).value();
DenseResourceElementsAttr valueAttr = getDenseResourceElementsAttrOfValue(
rewriter, stickifiedConstant, 0, sizeInBytes);
stickifiedConstant.setValueAttr(valueAttr);
res = stickifiedConstant.getResult();
} else {
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
Expand Down Expand Up @@ -706,7 +695,7 @@ struct ZHighToZLowUnstickOpLowering : public ConversionPattern {
};

//===----------------------------------------------------------------------===//
// Lower ZHigh Stickified Constant to ZLow Stickified Constant
// Lower ZHigh Stickified Constant to KrnlGlobal
//===----------------------------------------------------------------------===//

struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
Expand All @@ -719,7 +708,7 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
ZHighStickifiedConstantOp zhighStickifiedConstOp =
ZHighStickifiedConstantOp stickifiedConstOp =
llvm::dyn_cast<ZHighStickifiedConstantOp>(op);

// Convert ZTensor type to MemRefType.
Expand All @@ -733,59 +722,53 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));
ArrayRef<int64_t> normalizedShape = normalizedType.getShape();

// Create ZLowStickifiedConstantOp.
StringAttr layout =
getZTensorLayoutAttr(rewriter, *op->result_type_begin());
// Validate the stickified tensor.
Attribute valueAttr = stickifiedConstOp.getValueAttr();
int64_t sizeInBytes = getMemRefEltSizeInBytes(normalizedType);
sizeInBytes *= normalizedType.getNumElements();
if (auto denseAttr = mlir::dyn_cast_or_null<DenseElementsAttr>(valueAttr)) {
ArrayRef<char> data = denseAttr.getRawData();
if (denseAttr.isSplat()) {
// Constant ztensor's buffer is tensor<sizeInBytes x i8>.
int8_t v = denseAttr.getSplatValue<int8_t>();
// NNPA does not work with a splat buffer.
// Expand the memory buffer for NNPA by using DenseResourceElementsAttr.
valueAttr = getDenseResourceElementsAttrOfValue(
rewriter, stickifiedConstOp, v, sizeInBytes);
} else {
assert(
(data.size() == static_cast<uint64_t>(sizeInBytes)) &&
"The stickified tensor's buffer size and MemRef's size mismatched");
}
} else if (auto resourceAttr =
mlir::dyn_cast_or_null<DenseResourceElementsAttr>(
valueAttr)) {
auto blob = resourceAttr.getRawHandle().getBlob();
assert(blob && "Expecting dense resource with a valid blob");
ArrayRef<char> data = blob->getData();
assert(
(data.size() == static_cast<uint64_t>(sizeInBytes)) &&
"The stickified tensor's buffer size and MemRef's size mismatched");
} else {
llvm_unreachable("Unsupported ElementsAttr");
}

// Keep previous implementation about generating stickified data at
// ZHighConstPropagationPass. To use this, comment in and set directive "
// NNPA_ZHIGH_STICKIFIEDCONST_GEN"
//
// #ifdef NNPA_ZHIGH_STICKIFIEDCONST_GEN
// // Lower to KrnlGlobalOp
// // Get dense resource attribute.
// auto blob = mlir::cast<DenseResourceElementsAttr>(
// zhighStickifiedConstOp.getValue().value())
// .getRawHandle()
// .getBlob();
// assert(blob && "Expecting dense resource with a valid blob");
// ArrayRef<char> data = blob->getData();
// // Validate the stickified tensor.
// int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
// memRefSizeInBytes *= normalizedType.getNumElements();
// assert((data.size() == static_cast<uint64_t>(memRefSizeInBytes)) &&
// "The stickified tensor's buffer size and MemRef's size
// mismatched");
// // Create a KrnlGlobalOp.
// KrnlGlobalOp constantOp =
// rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
// /*shape=*/
// rewriter.getI64ArrayAttr(normalizedShape),
// /*name=*/
// rewriter.getStringAttr(
// "constant_stickify_" + std::to_string(constantID)),
// /*value=*/zhighStickifiedConstOp.getValueAttr(),
// /*offset=*/nullptr,
// /*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
// #else
ZLowStickifiedConstantOp constantOp =
rewriter.create<ZLowStickifiedConstantOp>(loc,
mlir::cast<MemRefType>(zMemRefType.value),
// Create a KrnlGlobalOp.
KrnlGlobalOp constantGlobal =
rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
/*shape=*/
rewriter.getI64ArrayAttr(normalizedShape),
/*name=*/
rewriter.getStringAttr(
"constant_stickify_" + std::to_string(constantID)),
/*stickified=*/zhighStickifiedConstOp.getStickifiedAttr(),
/*value=*/zhighStickifiedConstOp.getValueAttr(),
/*layout=*/layout,
/*offset=*/rewriter.getI64IntegerAttr(0),
/*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
// #endif // NNPA_ZHIGH_STICKIFIEDCONST_GEN
/*value=*/valueAttr,
/*offset=*/nullptr,
/*alignment=*/stickifiedConstOp.getAlignmentAttr());

// Increment constant ID:
constantID++;

rewriter.replaceOp(op, constantOp.getResult());
rewriter.replaceOp(op, constantGlobal.getResult());
return success();
}
};
Expand Down
1 change: 0 additions & 1 deletion src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ add_onnx_mlir_library(OMZHighOps
OMONNXOps # Use ONNXShapeHelper
OMLayoutHelper
OMShapeHelperOpInterface
OMStickify
OMNNPACompilerOptions
MLIRIR

Expand Down
5 changes: 1 addition & 4 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
Original file line number Diff line number Diff line change
Expand Up @@ -862,14 +862,11 @@ def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> {
let summary = "ZHigh Stickified Constant operation";
let description = [{
This operator produces a constant tensor to store stickified data.
`value` attribute has original constant or stickified constant.
`stickified` attribute indicates the `value` is already stickified or not.
Stickified data is opaque and must be 4K-aligned. One who produces
the stickified data must make sure its size in bytes is consistent with
the output tensor's size.
}];
let arguments = (ins BoolAttr:$stickified,
OptionalAttr<AnyAttr>:$value,
let arguments = (ins OptionalAttr<AnyAttr>:$value,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs AnyZTensor:$output);
}
Expand Down
Loading

0 comments on commit f3fec68

Please sign in to comment.