-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* update llvm commit hash to 2024-07-12 * update implementation for MLIR version bump * add dedicated test file * add first attempt at porting MulToAdd to PDLL * add LHS/RHS versions of power of two pattern * add PeelFromMul * add Cmake build --------- Co-authored-by: Jeremy Kun <[email protected]>
- Loading branch information
Showing
13 changed files
with
236 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule llvm-project
updated
from f778ea to 091354
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#include "mlir/Dialect/Arith/IR/ArithOps.td" | ||
|
||
Constraint IsPowerOfTwo(attr: Attr) [{ | ||
int64_t value = cast<::mlir::IntegerAttr>(attr).getValue().getSExtValue(); | ||
return success((value & (value - 1)) == 0); | ||
}]; | ||
|
||
// Currently, constraints that return values must be defined in C++ | ||
Constraint Halve(atttr: Attr) -> Attr; | ||
Constraint MinusOne(attr: Attr) -> Attr; | ||
|
||
// Replace y = C*x with y = C/2*x + C/2*x, when C is a power of 2, otherwise do | ||
// nothing. | ||
Pattern PowerOfTwoExpandRhs with benefit(2) { | ||
let root = op<arith.muli>(op<arith.constant> {value = const: Attr}, rhs: Value); | ||
IsPowerOfTwo(const); | ||
let halved: Attr = Halve(const); | ||
|
||
rewrite root with { | ||
let newConst = op<arith.constant> {value = halved}; | ||
let newMul = op<arith.muli>(newConst, rhs); | ||
let newAdd = op<arith.addi>(newMul, newMul); | ||
replace root with newAdd; | ||
}; | ||
} | ||
|
||
Pattern PowerOfTwoExpandLhs with benefit(2) { | ||
let root = op<arith.muli>(lhs: Value, op<arith.constant> {value = const: Attr}); | ||
IsPowerOfTwo(const); | ||
let halved: Attr = Halve(const); | ||
|
||
rewrite root with { | ||
let newConst = op<arith.constant> {value = halved}; | ||
let newMul = op<arith.muli>(lhs, newConst); | ||
let newAdd = op<arith.addi>(newMul, newMul); | ||
replace root with newAdd; | ||
}; | ||
} | ||
|
||
// Replace y = 9*x with y = 8*x + x | ||
Pattern PeelFromMulRhs with benefit(1) { | ||
let root = op<arith.muli>(lhs: Value, op<arith.constant> {value = const: Attr}); | ||
|
||
// We are guaranteed `value` is not a power of two, because the greedy | ||
// rewrite engine ensures the PowerOfTwoExpand pattern is run first, since | ||
// it has higher benefit. | ||
let minusOne: Attr = MinusOne(const); | ||
|
||
rewrite root with { | ||
let newConst = op<arith.constant> {value = minusOne}; | ||
let newMul = op<arith.muli>(lhs, newConst); | ||
let newAdd = op<arith.addi>(newMul, lhs); | ||
replace root with newAdd; | ||
}; | ||
} | ||
|
||
Pattern PeelFromMulLhs with benefit(1) { | ||
let root = op<arith.muli>(op<arith.constant> {value = const: Attr}, rhs: Value); | ||
let minusOne: Attr = MinusOne(const); | ||
|
||
rewrite root with { | ||
let newConst = op<arith.constant> {value = minusOne}; | ||
let newMul = op<arith.muli>(newConst, rhs); | ||
let newAdd = op<arith.addi>(newMul, rhs); | ||
replace root with newAdd; | ||
}; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#include "lib/Transform/Arith/MulToAddPdll.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "mlir/include/mlir/Pass/Pass.h" | ||
|
||
namespace mlir { | ||
namespace tutorial { | ||
|
||
#define GEN_PASS_DEF_MULTOADDPDLL | ||
#include "lib/Transform/Arith/Passes.h.inc" | ||
|
||
LogicalResult halveImpl(PatternRewriter &rewriter, PDLResultList &results, | ||
ArrayRef<PDLValue> args) { | ||
Attribute attr = args[0].cast<Attribute>(); | ||
IntegerAttr cAttr = cast<IntegerAttr>(attr); | ||
int64_t value = cAttr.getValue().getSExtValue(); | ||
results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value / 2)); | ||
return success(); | ||
} | ||
|
||
LogicalResult minusOneImpl(PatternRewriter &rewriter, PDLResultList &results, | ||
ArrayRef<PDLValue> args) { | ||
Attribute attr = args[0].cast<Attribute>(); | ||
IntegerAttr cAttr = cast<IntegerAttr>(attr); | ||
int64_t value = cAttr.getValue().getSExtValue(); | ||
results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value - 1)); | ||
return success(); | ||
} | ||
|
||
void registerNativeConstraints(RewritePatternSet &patterns) { | ||
patterns.getPDLPatterns().registerConstraintFunction("Halve", halveImpl); | ||
patterns.getPDLPatterns().registerConstraintFunction("MinusOne", minusOneImpl); | ||
} | ||
|
||
struct MulToAddPdll : impl::MulToAddPdllBase<MulToAddPdll> { | ||
using MulToAddPdllBase::MulToAddPdllBase; | ||
|
||
void runOnOperation() { | ||
mlir::RewritePatternSet patterns(&getContext()); | ||
populateGeneratedPDLLPatterns(patterns); | ||
registerNativeConstraints(patterns); | ||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); | ||
} | ||
}; | ||
|
||
} // namespace tutorial | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
#ifndef LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_ | ||
#define LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_ | ||
|
||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Parser/Parser.h" | ||
|
||
namespace mlir { | ||
namespace tutorial { | ||
|
||
#define GEN_PASS_DECL_MULTOADDPDLL | ||
#include "lib/Transform/Arith/Passes.h.inc" | ||
|
||
#include "lib/Transform/Arith/MulToAddPdll.h.inc" | ||
|
||
} // namespace tutorial | ||
} // namespace mlir | ||
|
||
#endif // LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// RUN: tutorial-opt %s --mul-to-add-pdll | FileCheck %s | ||
|
||
func.func @just_power_of_two(%arg: i32) -> i32 { | ||
%0 = arith.constant 8 : i32 | ||
%1 = arith.muli %arg, %0 : i32 | ||
func.return %1 : i32 | ||
} | ||
|
||
// CHECK-LABEL: func.func @just_power_of_two( | ||
// CHECK-SAME: %[[ARG:.*]]: i32 | ||
// CHECK-SAME: ) -> i32 { | ||
// CHECK: %[[SUM_0:.*]] = arith.addi %[[ARG]], %[[ARG]] | ||
// CHECK: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[SUM_0]] | ||
// CHECK: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[SUM_1]] | ||
// CHECK: return %[[SUM_2]] : i32 | ||
// CHECK: } | ||
|
||
|
||
func.func @power_of_two_plus_one(%arg: i32) -> i32 { | ||
%0 = arith.constant 9 : i32 | ||
%1 = arith.muli %arg, %0 : i32 | ||
func.return %1 : i32 | ||
} | ||
|
||
// CHECK-LABEL: func.func @power_of_two_plus_one( | ||
// CHECK-SAME: %[[ARG:.*]]: i32 | ||
// CHECK-SAME: ) -> i32 { | ||
// CHECK: %[[SUM_0:.*]] = arith.addi %[[ARG]], %[[ARG]] | ||
// CHECK: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[SUM_0]] | ||
// CHECK: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[SUM_1]] | ||
// CHECK: %[[SUM_3:.*]] = arith.addi %[[SUM_2]], %[[ARG]] | ||
// CHECK: return %[[SUM_3]] : i32 | ||
// CHECK: } |