Skip to content

Commit

Permalink
PDLL article (#43)
Browse files Browse the repository at this point in the history
* update llvm commit hash to 2024-07-12

* update implementation for MLIR version bump

* add dedicated test file

* add first attempt at porting MulToAdd to PDLL

* add LHS/RHS versions of power of two pattern

* add PeelFromMul

* add Cmake build

---------

Co-authored-by: Jeremy Kun <[email protected]>
  • Loading branch information
j2kun and j2kun authored Jul 31, 2024
1 parent 8ac609a commit e0c9f52
Show file tree
Hide file tree
Showing 13 changed files with 236 additions and 11 deletions.
4 changes: 2 additions & 2 deletions bazel/import_llvm.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ load(
def import_llvm(name):
"""Imports LLVM."""

# 2023-11-13
LLVM_COMMIT = "f778eafdd878e8b11ad76f9e0a312ce7791a7481"
# 2024-07-12
LLVM_COMMIT = "0913547d0e3939cc420e88ecd037240f33736820"

new_git_repository(
name = name,
Expand Down
2 changes: 1 addition & 1 deletion externals/llvm-project
6 changes: 3 additions & 3 deletions lib/Analysis/ReduceNoiseAnalysis/ReduceNoiseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ ReduceNoiseAnalysis::ReduceNoiseAnalysis(Operation *op) {
op->walk([&](Operation *op) {
// FIXME: assumes all reduce_noise ops have already been removed and their
// values forwarded.
if (!llvm::isa<noisy::AddOp, noisy::SubOp, noisy::MulOp>(op)) {
if (!isa<noisy::AddOp, noisy::SubOp, noisy::MulOp>(op)) {
return;
}

Expand Down Expand Up @@ -81,8 +81,8 @@ ReduceNoiseAnalysis::ReduceNoiseAnalysis(Operation *op) {
// In the tutorial, there is no control flow, so these are the function
// arguments of the main function being analyzed. A real compiler would
// need to handle this more generically.
if (value.isa<BlockArgument>() ||
llvm::isa<noisy::EncodeOp>(value.getDefiningOp())) {
if (isa<BlockArgument>(value) ||
isa<noisy::EncodeOp>(value.getDefiningOp())) {
MPConstraint *const ct =
solver->MakeRowConstraint(INITIAL_NOISE, INITIAL_NOISE, "");
ct->SetCoefficient(var, 1);
Expand Down
6 changes: 3 additions & 3 deletions lib/Dialect/Noisy/NoisyOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
class Noisy_BinOp<string mnemonic> : Op<Noisy_Dialect, mnemonic, [
Pure,
SameOperandsAndResultType,
DeclareOpInterfaceMethods<InferIntRangeInterface>
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]> {
let arguments = (ins Noisy_I32:$lhs, Noisy_I32:$rhs);
let results = (outs Noisy_I32:$output);
Expand All @@ -33,7 +33,7 @@ def Noisy_MulOp : Noisy_BinOp<"mul"> {
}

def Noisy_EncodeOp : Op<Noisy_Dialect, "encode", [
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "Encodes a noisy i32 from a small-width integer, injecting 12 bits of noise.";
let arguments = (ins AnyIntOfWidths<[1, 2, 3, 4, 5]>:$input);
let results = (outs Noisy_I32:$output);
Expand All @@ -48,7 +48,7 @@ def Noisy_DecodeOp : Op<Noisy_Dialect, "decode", [Pure]> {
}

def Noisy_ReduceNoiseOp : Op<Noisy_Dialect, "reduce_noise", [
Pure, SameOperandsAndResultType, DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
Pure, SameOperandsAndResultType, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "Reduces the noise in a noisy integer to a fixed noise level. Expensive!";
let arguments = (ins Noisy_I32:$input);
let results = (outs Noisy_I32:$output);
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Poly/PolyOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {

if (!lhs || !rhs) return nullptr;

auto degree = getResult().getType().cast<PolynomialType>().getDegreeBound();
auto degree = llvm::cast<PolynomialType>(getResult().getType()).getDegreeBound();
auto maxIndex = lhs.size() + rhs.size() - 1;

SmallVector<APInt, 8> result;
Expand Down Expand Up @@ -68,7 +68,7 @@ OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) {
LogicalResult EvalOp::verify() {
auto pointTy = getPoint().getType();
bool isSignlessInteger = pointTy.isSignlessInteger(32);
auto complexPt = llvm::dyn_cast<ComplexType>(pointTy);
auto complexPt = dyn_cast<ComplexType>(pointTy);
return isSignlessInteger || complexPt ? success()
: emitOpError(
"argument point must be a 32-bit "
Expand Down
35 changes: 35 additions & 0 deletions lib/Transform/Arith/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ gentbl_cc_library(
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PDLDialectTdFiles",
"@llvm-project//mlir:PDLInterpOpsTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
Expand All @@ -47,6 +49,39 @@ cc_library(
hdrs = ["Passes.h"],
deps = [
":MulToAdd",
":MulToAddPdll",
":pass_inc_gen",
],
)

gentbl_cc_library(
name = "MulToAddPdllIncGen",
tbl_outs = [
(
["-x=cpp"],
"MulToAddPdll.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-pdll",
td_file = "MulToAdd.pdll",
deps = [
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:ArithOpsTdFiles",
],
)

cc_library(
name = "MulToAddPdll",
srcs = ["MulToAddPdll.cpp"],
hdrs = ["MulToAddPdll.h"],
deps = [
":pass_inc_gen",
":MulToAddPdllIncGen",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)

7 changes: 7 additions & 0 deletions lib/Transform/Arith/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
add_mlir_pdll_library(MulToAddPdllIncGen
MulToAdd.pdll
MulToAddPdll.h.inc
)

add_mlir_library(MulToAdd
MulToAdd.cpp
MulToAddPdll.cpp

${PROJECT_SOURCE_DIR}/lib/Transform/Arith/
ADDITIONAL_HEADER_DIRS

DEPENDS
MLIRMulToAddPasses
MulToAddPdllIncGen

LINK_LIBS PUBLIC
)
Expand Down
67 changes: 67 additions & 0 deletions lib/Transform/Arith/MulToAdd.pdll
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "mlir/Dialect/Arith/IR/ArithOps.td"

Constraint IsPowerOfTwo(attr: Attr) [{
int64_t value = cast<::mlir::IntegerAttr>(attr).getValue().getSExtValue();
return success((value & (value - 1)) == 0);
}];

// Currently, constraints that return values must be defined in C++
Constraint Halve(atttr: Attr) -> Attr;
Constraint MinusOne(attr: Attr) -> Attr;

// Replace y = C*x with y = C/2*x + C/2*x, when C is a power of 2, otherwise do
// nothing.
Pattern PowerOfTwoExpandRhs with benefit(2) {
let root = op<arith.muli>(op<arith.constant> {value = const: Attr}, rhs: Value);
IsPowerOfTwo(const);
let halved: Attr = Halve(const);

rewrite root with {
let newConst = op<arith.constant> {value = halved};
let newMul = op<arith.muli>(newConst, rhs);
let newAdd = op<arith.addi>(newMul, newMul);
replace root with newAdd;
};
}

Pattern PowerOfTwoExpandLhs with benefit(2) {
let root = op<arith.muli>(lhs: Value, op<arith.constant> {value = const: Attr});
IsPowerOfTwo(const);
let halved: Attr = Halve(const);

rewrite root with {
let newConst = op<arith.constant> {value = halved};
let newMul = op<arith.muli>(lhs, newConst);
let newAdd = op<arith.addi>(newMul, newMul);
replace root with newAdd;
};
}

// Replace y = 9*x with y = 8*x + x
Pattern PeelFromMulRhs with benefit(1) {
let root = op<arith.muli>(lhs: Value, op<arith.constant> {value = const: Attr});

// We are guaranteed `value` is not a power of two, because the greedy
// rewrite engine ensures the PowerOfTwoExpand pattern is run first, since
// it has higher benefit.
let minusOne: Attr = MinusOne(const);

rewrite root with {
let newConst = op<arith.constant> {value = minusOne};
let newMul = op<arith.muli>(lhs, newConst);
let newAdd = op<arith.addi>(newMul, lhs);
replace root with newAdd;
};
}

Pattern PeelFromMulLhs with benefit(1) {
let root = op<arith.muli>(op<arith.constant> {value = const: Attr}, rhs: Value);
let minusOne: Attr = MinusOne(const);

rewrite root with {
let newConst = op<arith.constant> {value = minusOne};
let newMul = op<arith.muli>(newConst, rhs);
let newAdd = op<arith.addi>(newMul, rhs);
replace root with newAdd;
};
}
49 changes: 49 additions & 0 deletions lib/Transform/Arith/MulToAddPdll.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include "lib/Transform/Arith/MulToAddPdll.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/include/mlir/Pass/Pass.h"

namespace mlir {
namespace tutorial {

#define GEN_PASS_DEF_MULTOADDPDLL
#include "lib/Transform/Arith/Passes.h.inc"

LogicalResult halveImpl(PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> args) {
Attribute attr = args[0].cast<Attribute>();
IntegerAttr cAttr = cast<IntegerAttr>(attr);
int64_t value = cAttr.getValue().getSExtValue();
results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value / 2));
return success();
}

LogicalResult minusOneImpl(PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> args) {
Attribute attr = args[0].cast<Attribute>();
IntegerAttr cAttr = cast<IntegerAttr>(attr);
int64_t value = cAttr.getValue().getSExtValue();
results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value - 1));
return success();
}

void registerNativeConstraints(RewritePatternSet &patterns) {
patterns.getPDLPatterns().registerConstraintFunction("Halve", halveImpl);
patterns.getPDLPatterns().registerConstraintFunction("MinusOne", minusOneImpl);
}

struct MulToAddPdll : impl::MulToAddPdllBase<MulToAddPdll> {
using MulToAddPdllBase::MulToAddPdllBase;

void runOnOperation() {
mlir::RewritePatternSet patterns(&getContext());
populateGeneratedPDLLPatterns(patterns);
registerNativeConstraints(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace tutorial
} // namespace mlir
20 changes: 20 additions & 0 deletions lib/Transform/Arith/MulToAddPdll.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_
#define LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_

#include "mlir/Pass/Pass.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Parser/Parser.h"

namespace mlir {
namespace tutorial {

#define GEN_PASS_DECL_MULTOADDPDLL
#include "lib/Transform/Arith/Passes.h.inc"

#include "lib/Transform/Arith/MulToAddPdll.h.inc"

} // namespace tutorial
} // namespace mlir

#endif // LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_
1 change: 1 addition & 0 deletions lib/Transform/Arith/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define LIB_TRANSFORM_ARITH_PASSES_H_

#include "lib/Transform/Arith/MulToAdd.h"
#include "lib/Transform/Arith/MulToAddPdll.h"

namespace mlir {
namespace tutorial {
Expand Down
13 changes: 13 additions & 0 deletions lib/Transform/Arith/Passes.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef LIB_TRANSFORM_ARITH_PASSES_TD_
#define LIB_TRANSFORM_ARITH_PASSES_TD_

include "mlir/Dialect/PDL/IR/PDLDialect.td"
include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.td"
include "mlir/Pass/PassBase.td"

def MulToAdd : Pass<"mul-to-add"> {
Expand All @@ -10,4 +12,15 @@ def MulToAdd : Pass<"mul-to-add"> {
}];
}

def MulToAddPdll : Pass<"mul-to-add-pdll"> {
let summary = "Convert multiplications to repeated additions using pdll";
let description = [{
Convert multiplications to repeated additions (using pdll).
}];
let dependentDialects = [
"mlir::pdl::PDLDialect",
"mlir::pdl_interp::PDLInterpDialect",
];
}

#endif // LIB_TRANSFORM_ARITH_PASSES_TD_
33 changes: 33 additions & 0 deletions tests/mul_to_add_pdll.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: tutorial-opt %s --mul-to-add-pdll | FileCheck %s

func.func @just_power_of_two(%arg: i32) -> i32 {
%0 = arith.constant 8 : i32
%1 = arith.muli %arg, %0 : i32
func.return %1 : i32
}

// CHECK-LABEL: func.func @just_power_of_two(
// CHECK-SAME: %[[ARG:.*]]: i32
// CHECK-SAME: ) -> i32 {
// CHECK: %[[SUM_0:.*]] = arith.addi %[[ARG]], %[[ARG]]
// CHECK: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[SUM_0]]
// CHECK: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[SUM_1]]
// CHECK: return %[[SUM_2]] : i32
// CHECK: }


func.func @power_of_two_plus_one(%arg: i32) -> i32 {
%0 = arith.constant 9 : i32
%1 = arith.muli %arg, %0 : i32
func.return %1 : i32
}

// CHECK-LABEL: func.func @power_of_two_plus_one(
// CHECK-SAME: %[[ARG:.*]]: i32
// CHECK-SAME: ) -> i32 {
// CHECK: %[[SUM_0:.*]] = arith.addi %[[ARG]], %[[ARG]]
// CHECK: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[SUM_0]]
// CHECK: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[SUM_1]]
// CHECK: %[[SUM_3:.*]] = arith.addi %[[SUM_2]], %[[ARG]]
// CHECK: return %[[SUM_3]] : i32
// CHECK: }

0 comments on commit e0c9f52

Please sign in to comment.