Commit

Code cleanup and some upgrade to builders (#2920)
Signed-off-by: Alexandre Eichenberger <[email protected]>
AlexandreEichenberger authored Aug 28, 2024
1 parent 4f0a141 commit f7d8db5
Showing 159 changed files with 1,090 additions and 1,080 deletions.
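
Note: most of this diff applies one mechanical pattern: unqualified cast<> / dyn_cast<> calls become the explicitly qualified free functions mlir::cast<> / mlir::dyn_cast<>, and redundant mlir:: prefixes are dropped where a using-declaration already makes the name visible. A minimal before/after sketch of the pattern (illustrative only, not code from this commit; `value` is an arbitrary mlir::Value):

    // Before: relies on unqualified name lookup to find dyn_cast.
    auto shapedOld = dyn_cast<ShapedType>(value.getType());

    // After: the free function is qualified with its owning namespace.
    auto shapedNew = mlir::dyn_cast<ShapedType>(value.getType());

    // And where `using namespace mlir;` is already in effect, spelled-out
    // element types lose the redundant qualifier:
    SmallVector<Type> outputTypes; // was: SmallVector<mlir::Type>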
@@ -207,7 +207,7 @@ Value getLSTMGRUGetYc(
 SmallVector<Value, 4> emitONNXSplitOp(Location loc, PatternRewriter &rewriter,
     Value input, IntegerAttr axis, ArrayAttr split) {
   Type elementType = mlir::cast<ShapedType>(input.getType()).getElementType();
-  SmallVector<mlir::Type> outputTypes;
+  SmallVector<Type> outputTypes;
   int64_t splitNum = split.size();
   ArrayRef<int64_t> inputShape =
       mlir::cast<RankedTensorType>(input.getType()).getShape();
@@ -73,10 +73,10 @@ ValueRange splitAlongAxis(
   return splits;
 }
 
-bool isF32ScalarConstantTensor(mlir::Value v) {
+bool isF32ScalarConstantTensor(Value v) {
   if (!isScalarConstantTensor(v))
     return false;
-  auto t = dyn_cast<ShapedType>(v.getType());
+  auto t = mlir::dyn_cast<ShapedType>(v.getType());
   return t.getElementType().isF32();
 }
 
@@ -93,7 +93,7 @@ Value getDynShape(Location loc, PatternRewriter &rewriter, Value x) {
     llvm_unreachable("The input must have shape and rank");
 
   OnnxBuilder create(rewriter, loc);
-  auto t = dyn_cast<ShapedType>(x.getType());
+  auto t = mlir::dyn_cast<ShapedType>(x.getType());
   int64_t r = t.getRank();
   SmallVector<Value> dims;
   for (int64_t i = 0; i < r; ++i) {
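
For readers skimming the cast changes: mlir::dyn_cast returns a null handle when the runtime type does not match, while mlir::cast asserts on a mismatch. That is why isF32ScalarConstantTensor above can use the dyn_cast result without a null check: isScalarConstantTensor has already vetted the type. A small sketch of the two behaviors (inspect is a hypothetical helper; header paths approximate):

    #include "mlir/IR/BuiltinTypes.h"     // mlir::ShapedType
    #include "mlir/IR/Value.h"            // mlir::Value
    #include "llvm/Support/raw_ostream.h" // llvm::outs

    // Hypothetical helper illustrating dyn_cast vs. cast semantics.
    void inspect(mlir::Value v) {
      // dyn_cast: test-and-use; the branch runs only if the type is shaped.
      if (auto t = mlir::dyn_cast<mlir::ShapedType>(v.getType()))
        llvm::outs() << "element type: " << t.getElementType() << "\n";
      // cast: unconditional; asserts (in debug builds) when not shaped.
      auto st = mlir::cast<mlir::ShapedType>(v.getType());
      (void)st;
    }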
32 changes: 16 additions & 16 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/PerfModel.cpp
@@ -416,44 +416,44 @@ double estimateTimeForUnstickOp(Value oper) {
 bool estimateTimeForOpWithModel(Operation *op, const DimAnalysis *dimAnalysis,
     double &cpuEstimatedTime, double &nnpaEstimatedTime) {
   bool opHasModel = true;
-  if (auto addOp = dyn_cast<ONNXAddOp>(op))
+  if (auto addOp = mlir::dyn_cast<ONNXAddOp>(op))
     estimateTimeForOp(addOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
-  else if (auto divOp = dyn_cast<ONNXDivOp>(op))
+  else if (auto divOp = mlir::dyn_cast<ONNXDivOp>(op))
     estimateTimeForOp(divOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
-  else if (auto maxOp = dyn_cast<ONNXMaxOp>(op))
+  else if (auto maxOp = mlir::dyn_cast<ONNXMaxOp>(op))
     estimateTimeForOp(maxOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
-  else if (auto minOp = dyn_cast<ONNXMinOp>(op))
+  else if (auto minOp = mlir::dyn_cast<ONNXMinOp>(op))
     estimateTimeForOp(minOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
-  else if (auto mulOp = dyn_cast<ONNXMulOp>(op))
+  else if (auto mulOp = mlir::dyn_cast<ONNXMulOp>(op))
     estimateTimeForOp(mulOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
-  else if (auto powOp = dyn_cast<ONNXPowOp>(op))
+  else if (auto powOp = mlir::dyn_cast<ONNXPowOp>(op))
     estimateTimeForOp(powOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
-  else if (auto subOp = dyn_cast<ONNXSubOp>(op))
+  else if (auto subOp = mlir::dyn_cast<ONNXSubOp>(op))
     estimateTimeForOp(subOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
   // Unary elementwise NNPA candidate ops.
-  else if (auto expOp = dyn_cast<ONNXExpOp>(op))
+  else if (auto expOp = mlir::dyn_cast<ONNXExpOp>(op))
     estimateTimeForOp(expOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
-  else if (auto logOp = dyn_cast<ONNXLogOp>(op))
+  else if (auto logOp = mlir::dyn_cast<ONNXLogOp>(op))
     estimateTimeForOp(logOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
-  else if (auto reluOp = dyn_cast<ONNXReluOp>(op))
+  else if (auto reluOp = mlir::dyn_cast<ONNXReluOp>(op))
     estimateTimeForOp(reluOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
-  else if (auto sigmoidOp = dyn_cast<ONNXSigmoidOp>(op))
+  else if (auto sigmoidOp = mlir::dyn_cast<ONNXSigmoidOp>(op))
     estimateTimeForOp(
         sigmoidOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
-  else if (auto softmaxOp = dyn_cast<ONNXSoftmaxOp>(op))
+  else if (auto softmaxOp = mlir::dyn_cast<ONNXSoftmaxOp>(op))
     estimateTimeForOp(
         softmaxOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
-  else if (auto tanhOp = dyn_cast<ONNXTanhOp>(op))
+  else if (auto tanhOp = mlir::dyn_cast<ONNXTanhOp>(op))
     estimateTimeForOp(tanhOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
   // Reduce
-  else if (auto reduceMeanOp = dyn_cast<ONNXReduceMeanV13Op>(op))
+  else if (auto reduceMeanOp = mlir::dyn_cast<ONNXReduceMeanV13Op>(op))
     estimateTimeForOp(
         reduceMeanOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
   // Matmul.
-  else if (auto matMulOp = dyn_cast<ONNXMatMulOp>(op))
+  else if (auto matMulOp = mlir::dyn_cast<ONNXMatMulOp>(op))
     estimateTimeForOp(
         matMulOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
-  else if (auto gemmOp = dyn_cast<ONNXGemmOp>(op))
+  else if (auto gemmOp = mlir::dyn_cast<ONNXGemmOp>(op))
     estimateTimeForOp(gemmOp, dimAnalysis, cpuEstimatedTime, nnpaEstimatedTime);
   else
     opHasModel = false;
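
The chain above is the usual dyn_cast dispatch idiom: each mlir::dyn_cast yields a non-null typed handle only for the matching op class, so exactly one estimateTimeForOp overload runs, and the trailing else records that no performance model exists for this op. A condensed, hypothetical sketch of the shape (handle stands in for the real callee; op types come from the surrounding file):

    // Condensed, hypothetical form of the dispatch above.
    bool dispatchOnOpKind(mlir::Operation *op) {
      if (auto addOp = mlir::dyn_cast<ONNXAddOp>(op))
        handle(addOp); // non-null only when op is an ONNXAddOp
      else if (auto mulOp = mlir::dyn_cast<ONNXMulOp>(op))
        handle(mulOp); // non-null only when op is an ONNXMulOp
      else
        return false;  // no model known for this op kind
      return true;
    }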
@@ -219,7 +219,7 @@ bool canInferencePadsForNNPAConv(ONNXConvOp op) {
 // Create an ArrayAttr of IntegerAttr(s) of zero values.
 // This function is used for padding attribute in Conv.
 ArrayAttr getPadsForNNPAConv(PatternRewriter &rewriter, Value ret) {
-  ONNXConvOp op = dyn_cast<ONNXConvOp>(ret.getDefiningOp());
+  ONNXConvOp op = mlir::dyn_cast<ONNXConvOp>(ret.getDefiningOp());
   ONNXConvOpShapeHelper shapeHelper(op.getOperation(), {});
   shapeHelper.computeShapeAndAssertOnFailure();
   SmallVector<int64_t, 4> vals;
@@ -451,7 +451,7 @@ class AddSubWithRHSZeroExpandPattern : public OpRewritePattern<OP_TYPE> {
     if (isa<BlockArgument>(B))
      return false;
    bool BIsZero = false;
-    if (auto expandOp = dyn_cast<ONNXExpandOp>(B.getDefiningOp())) {
+    if (auto expandOp = mlir::dyn_cast<ONNXExpandOp>(B.getDefiningOp())) {
      Value input = expandOp.getInput();
      if (isDenseONNXConstant(input)) {
        // Expand's input is 0?
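
Per the comment in the first hunk, getPadsForNNPAConv materializes an all-zero pads attribute for Conv. A hedged sketch of how such an attribute can be built (getZeroPads is a hypothetical helper; Builder::getI64ArrayAttr is the real MLIR API):

    // Hypothetical helper: an all-zero Conv `pads` attribute, laid out
    // as [begin_0, ..., begin_{n-1}, end_0, ..., end_{n-1}].
    mlir::ArrayAttr getZeroPads(
        mlir::PatternRewriter &rewriter, int64_t spatialRank) {
      llvm::SmallVector<int64_t, 4> vals(2 * spatialRank, 0);
      return rewriter.getI64ArrayAttr(vals);
    }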
4 changes: 2 additions & 2 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
@@ -501,7 +501,7 @@ struct ZHighToZLowStickOpLowering : public ConversionPattern {
   LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     Location loc = op->getLoc();
-    ZHighStickOp stickOp = cast<ZHighStickOp>(op);
+    ZHighStickOp stickOp = mlir::cast<ZHighStickOp>(op);
 
     ZHighStickOpAdaptor operandAdaptor(operands);
     Value input = operandAdaptor.getIn();
@@ -1557,7 +1557,7 @@ struct ZHighToZLowStickifiedConstantOfShapeOpLowering
     Location loc = op->getLoc();
     MDBuilder create(rewriter, loc);
 
-    auto stickOp = cast<ZHighStickifiedConstantOfShapeOp>(op);
+    auto stickOp = mlir::cast<ZHighStickifiedConstantOfShapeOp>(op);
     FloatAttr value = stickOp.getValueAttr();
     Type i16Ty = rewriter.getI16Type();
     Type i64Ty = rewriter.getI64Type();
28 changes: 14 additions & 14 deletions src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVM.cpp
@@ -98,7 +98,7 @@ class ZLowStickLowering : public mlir::ConvertToLLVMPattern {
       ConversionPatternRewriter &rewriter) const override {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
-    ZLowStickOp stickOp = cast<ZLowStickOp>(op);
+    ZLowStickOp stickOp = mlir::cast<ZLowStickOp>(op);
 
     ZLowStickOpAdaptor operandAdaptor(operands);
     // Do not get element type from adaptor since the type can be opaque.
@@ -154,7 +154,7 @@ class ZLowStickForLSTMLowering : public ConvertToLLVMPattern {
       ConversionPatternRewriter &rewriter) const override {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
-    ZLowStickForLSTMOp stickForLSTMOp = cast<ZLowStickForLSTMOp>(op);
+    ZLowStickForLSTMOp stickForLSTMOp = mlir::cast<ZLowStickForLSTMOp>(op);
 
     ZLowStickForLSTMOpAdaptor operandAdaptor(operands);
     Type llvmElementTy = typeConverter->convertType(
@@ -240,7 +240,7 @@ class ZLowStickForGRULowering : public ConvertToLLVMPattern {
       ConversionPatternRewriter &rewriter) const override {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
-    ZLowStickForGRUOp stickForGRUOp = cast<ZLowStickForGRUOp>(op);
+    ZLowStickForGRUOp stickForGRUOp = mlir::cast<ZLowStickForGRUOp>(op);
 
     ZLowStickForGRUOpAdaptor operandAdaptor(operands);
     Type llvmElementTy = typeConverter->convertType(
@@ -324,7 +324,7 @@ class ZLowLSTMLowering : public ConvertToLLVMPattern {
       ConversionPatternRewriter &rewriter) const override {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
-    ZLowLSTMOp lstmOp = cast<ZLowLSTMOp>(op);
+    ZLowLSTMOp lstmOp = mlir::cast<ZLowLSTMOp>(op);
     MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
 
     ZLowLSTMOpAdaptor operandAdaptor(operands);
@@ -520,7 +520,7 @@ class ZLowGRULowering : public ConvertToLLVMPattern {
       ConversionPatternRewriter &rewriter) const override {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
-    ZLowGRUOp gruOp = cast<ZLowGRUOp>(op);
+    ZLowGRUOp gruOp = mlir::cast<ZLowGRUOp>(op);
     MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
 
     ZLowGRUOpAdaptor operandAdaptor(operands);
@@ -675,7 +675,7 @@ class ZLowUnstickLowering : public ConvertToLLVMPattern {
       ConversionPatternRewriter &rewriter) const override {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
-    ZLowUnstickOp unstickOp = cast<ZLowUnstickOp>(op);
+    ZLowUnstickOp unstickOp = mlir::cast<ZLowUnstickOp>(op);
 
     ZLowUnstickOpAdaptor operandAdaptor(operands);
     Type llvmElementTy = typeConverter->convertType(
@@ -732,7 +732,7 @@ class ZLowUnaryElementwiseOpLowering : public ConvertToLLVMPattern {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
     MLIRContext *context = rewriter.getContext();
-    UnaryElementwiseOp unaryOp = cast<UnaryElementwiseOp>(op);
+    UnaryElementwiseOp unaryOp = mlir::cast<UnaryElementwiseOp>(op);
     typename UnaryElementwiseOp::Adaptor operandAdaptor(operands);
     MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
 
@@ -810,7 +810,7 @@ class ZLowBinaryElementwiseOpLowering : public ConvertToLLVMPattern {
       ConversionPatternRewriter &rewriter) const override {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
-    BinaryElementwiseOp binaryOp = cast<BinaryElementwiseOp>(op);
+    BinaryElementwiseOp binaryOp = mlir::cast<BinaryElementwiseOp>(op);
     typename BinaryElementwiseOp::Adaptor operandAdaptor(operands);
 
     Value input1 = operandAdaptor.getX();
@@ -888,7 +888,7 @@ class ZLowSoftmaxOpLowering : public ConvertToLLVMPattern {
       ConversionPatternRewriter &rewriter) const override {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
-    ZLowSoftmaxOp softmaxOp = cast<ZLowSoftmaxOp>(op);
+    ZLowSoftmaxOp softmaxOp = mlir::cast<ZLowSoftmaxOp>(op);
     MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
 
     ZLowSoftmaxOpAdaptor operandAdaptor(operands);
@@ -971,7 +971,7 @@ class ZLowMatMulLowering : public ConvertToLLVMPattern {
       ConversionPatternRewriter &rewriter) const override {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
-    ZLowMatMulOp matmulOp = cast<ZLowMatMulOp>(op);
+    ZLowMatMulOp matmulOp = mlir::cast<ZLowMatMulOp>(op);
     MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
 
     ZLowMatMulOpAdaptor operandAdaptor(operands);
@@ -1109,7 +1109,7 @@ class ZLowConv2DLowering : public ConvertToLLVMPattern {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
     MLIRContext *context = rewriter.getContext();
-    ZLowConv2DOp convOp = cast<ZLowConv2DOp>(op);
+    ZLowConv2DOp convOp = mlir::cast<ZLowConv2DOp>(op);
     ZLowConv2DOpAdaptor operandAdaptor(operands);
     MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
 
@@ -1256,7 +1256,7 @@ class ZLowPool2DLowering : public ConvertToLLVMPattern {
       ConversionPatternRewriter &rewriter) const override {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
-    POOLOP poolOp = cast<POOLOP>(op);
+    POOLOP poolOp = mlir::cast<POOLOP>(op);
     typename POOLOP::Adaptor operandAdaptor(operands);
     MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
 
@@ -1360,7 +1360,7 @@ class ZLowMeanReduce2DLowering : public ConvertToLLVMPattern {
       ConversionPatternRewriter &rewriter) const override {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
-    ZLowMeanReduce2DOp meanOp = cast<ZLowMeanReduce2DOp>(op);
+    ZLowMeanReduce2DOp meanOp = mlir::cast<ZLowMeanReduce2DOp>(op);
     MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
 
     ZLowMeanReduce2DOpAdaptor operandAdaptor(operands);
@@ -1429,7 +1429,7 @@ class ZLowBatchNormLowering : public ConvertToLLVMPattern {
       ConversionPatternRewriter &rewriter) const override {
     ModuleOp module = op->getParentOfType<ModuleOp>();
     Location loc = op->getLoc();
-    ZLowBatchNormOp batchnormOp = cast<ZLowBatchNormOp>(op);
+    ZLowBatchNormOp batchnormOp = mlir::cast<ZLowBatchNormOp>(op);
 
     ZLowBatchNormOpAdaptor operandAdaptor(operands);
     Type llvmElementTy = typeConverter->convertType(
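
Every pattern touched in this file follows the same skeleton: matchAndRewrite recovers the typed op with mlir::cast (safe because each pattern is registered for exactly that op), wraps the already type-converted operands in the generated adaptor, and emits LLVM-dialect IR at the op's location. A stripped-down sketch of that skeleton (MyOp and the rewrite body are placeholders, not code from this commit):

    // Illustrative skeleton only; MyOp and the rewrite body are placeholders.
    template <typename MyOp>
    class MyOpLowering : public mlir::ConvertToLLVMPattern {
    public:
      using mlir::ConvertToLLVMPattern::ConvertToLLVMPattern;
      mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
          llvm::ArrayRef<mlir::Value> operands,
          mlir::ConversionPatternRewriter &rewriter) const override {
        mlir::Location loc = op->getLoc();
        auto myOp = mlir::cast<MyOp>(op);       // pattern only fires on MyOp
        typename MyOp::Adaptor operandAdaptor(operands); // converted operands
        // ... build LLVM-dialect ops / runtime calls at `loc` here ...
        rewriter.eraseOp(op); // or rewriter.replaceOp(op, newResults)
        return mlir::success();
      }
    };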
@@ -364,7 +364,7 @@ std::vector<Value> getDimsFromShapeMemRefBySize(PatternRewriter &rewriter,
           bitcastOp.getArg().getDefiningOp());
   if (addressOfOp) {
     LLVM::GlobalOp globalOp =
-        dyn_cast_or_null<LLVM::GlobalOp>(SymbolTable::lookupSymbolIn(
+        mlir::dyn_cast_or_null<LLVM::GlobalOp>(SymbolTable::lookupSymbolIn(
             module, addressOfOp.getGlobalNameAttr()));
     if (globalOp) {
       DenseElementsAttr valueAttr =
4 changes: 2 additions & 2 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.cpp
@@ -71,10 +71,10 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultLayout(Operation *op) {
 namespace onnx_mlir {
 namespace zhigh {
 
-std::vector<mlir::Type> getZHighAuxSplitResultType(
+std::vector<Type> getZHighAuxSplitResultType(
     Value input, int64_t axis, ArrayAttr split) {
   Type elementType = mlir::cast<ShapedType>(input.getType()).getElementType();
-  std::vector<mlir::Type> outputTypes;
+  std::vector<Type> outputTypes;
   if (split.size() == 0) {
     llvm_unreachable("Unsupported split (size==0)");
   } else {
@@ -24,7 +24,7 @@ namespace zhigh {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ZHighBatchNormOp::inferShapes(
-    std::function<void(mlir::Region &)> doShapeInference) {
+    std::function<void(Region &)> doShapeInference) {
   return inferShapeForUnaryOps(this->getOperation());
 }
 
@@ -139,7 +139,7 @@ LogicalResult ZHighConv2DOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult ZHighConv2DOp::inferShapes(
-    std::function<void(mlir::Region &)> doShapeInference) {
+    std::function<void(Region &)> doShapeInference) {
   if (!hasRankedType(getInput()) || !hasRankedType(getInputKernel()))
     return success();
 
@@ -33,7 +33,7 @@ void ZHighDLF16ToF32Op::build(
   Type elementType = builder.getF32Type();
   Type resType = UnrankedTensorType::get(elementType);
 
-  if (auto inType = dyn_cast<RankedTensorType>(input.getType()))
+  if (auto inType = mlir::dyn_cast<RankedTensorType>(input.getType()))
     resType = RankedTensorType::get(inType.getShape(), elementType);
 
   build(builder, state, resType, input);
@@ -44,7 +44,7 @@ void ZHighDLF16ToF32Op::build(
 //===----------------------------------------------------------------------===//
 
 LogicalResult ZHighDLF16ToF32Op::inferShapes(
-    std::function<void(mlir::Region &)> doShapeInference) {
+    std::function<void(Region &)> doShapeInference) {
   return inferShapeForUnaryOps(this->getOperation());
 }
 
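
The three inferShapes hunks above only drop the redundant mlir:: qualifier on Region; the inference logic is untouched. For unary ops such as these, inferShapeForUnaryOps amounts to propagating the operand's ranked tensor type to the single result, roughly as below (a hedged sketch, not the library's actual implementation):

    // Rough sketch of unary shape inference (assumes one operand, one
    // result, element type preserved); not the real inferShapeForUnaryOps.
    mlir::LogicalResult inferUnary(mlir::Operation *op) {
      auto inTy =
          mlir::dyn_cast<mlir::RankedTensorType>(op->getOperand(0).getType());
      if (!inTy)
        return mlir::success(); // input not ranked yet; nothing to infer
      op->getResult(0).setType(
          mlir::RankedTensorType::get(inTy.getShape(), inTy.getElementType()));
      return mlir::success();
    }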
(Remaining changed files in this commit are not shown.)
