-
Notifications
You must be signed in to change notification settings - Fork 14k
switch type and value ordering for arith Constant[XX]Op
#144636
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
switch type and value ordering for arith Constant[XX]Op
#144636
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir-quant @llvm/pr-subscribers-mlir-gpu Author: Skrai Pardus (ashjeong) ChangesChange made to standardize with all other Full diff: https://github.com/llvm/llvm-project/pull/144636.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 77241319851e6..0bee876ac9bfa 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -62,8 +62,8 @@ class ConstantIntOp : public arith::ConstantOp {
/// Build a constant int op that produces an integer of the specified type,
/// which must be an integer type.
- static void build(OpBuilder &builder, OperationState &result, int64_t value,
- Type type);
+ static void build(OpBuilder &builder, OperationState &result, Type type,
+ int64_t value);
inline int64_t value() {
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
@@ -79,8 +79,8 @@ class ConstantFloatOp : public arith::ConstantOp {
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
/// Build a constant float op that produces a float of the specified type.
- static void build(OpBuilder &builder, OperationState &result,
- const APFloat &value, FloatType type);
+ static void build(OpBuilder &builder, OperationState &result, FloatType type,
+ const APFloat &value);
inline APFloat value() {
return cast<FloatAttr>(arith::ConstantOp::getValue()).getValue();
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6d73f23e2aae1..923f5f67b865a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -244,11 +244,11 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// Clamp to the negation range.
Value min = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
- intermediateType);
+ loc, intermediateType,
+ APInt::getSignedMinValue(inputBitWidth).getSExtValue());
Value max = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
- intermediateType);
+ loc, intermediateType,
+ APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
// Truncate to the final value.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 86f5e9baf4a94..c460a8bb2f4b2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1073,11 +1073,11 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
auto min = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
- accETy);
+ loc, accETy,
+ APInt::getSignedMinValue(outBitwidth).getSExtValue());
auto max = rewriter.create<arith::ConstantIntOp>(
- loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
- accETy);
+ loc, accETy,
+ APInt::getSignedMaxValue(outBitwidth).getSExtValue());
auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
/*isUnsigned=*/false);
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 9e53e195274aa..b9f91a0509103 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -257,9 +257,7 @@ void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
}
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
- int64_t value, Type type) {
- assert(type.isSignlessInteger() &&
- "ConstantIntOp can only have signless integer type values");
+ Type type, int64_t value) {
arith::ConstantOp::build(builder, result, type,
builder.getIntegerAttr(type, value));
}
@@ -271,7 +269,7 @@ bool arith::ConstantIntOp::classof(Operation *op) {
}
void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
- const APFloat &value, FloatType type) {
+ FloatType type, const APFloat &value) {
arith::ConstantOp::build(builder, result, type,
builder.getFloatAttr(type, value));
}
@@ -2363,7 +2361,7 @@ struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
rewriter.create<arith::XOrIOp>(
op.getLoc(), op.getCondition(),
rewriter.create<arith::ConstantIntOp>(
- op.getLoc(), 1, op.getCondition().getType())));
+ op.getLoc(), op.getCondition().getType(), 1)));
return success();
}
diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
index bb4807ab39cd6..3cd8684878a11 100644
--- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp
@@ -216,7 +216,7 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
from = b.create<arith::TruncFOp>(toFpTy, from);
}
Value zero = b.create<mlir::arith::ConstantFloatOp>(
- mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
+ toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
return b.create<complex::CreateOp>(targetType, from, zero);
}
@@ -229,7 +229,7 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand,
from = b.create<arith::SIToFPOp>(toFpTy, from);
}
Value zero = b.create<mlir::arith::ConstantFloatOp>(
- mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
+ toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
return b.create<complex::CreateOp>(targetType, from, zero);
}
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 9c776dfa176a4..27fa92cee79c2 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -820,13 +820,13 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
const float initialOvershardingFactor = 8.0f;
Value scalingFactor = b.create<arith::ConstantFloatOp>(
- llvm::APFloat(initialOvershardingFactor), b.getF32Type());
+ b.getF32Type(), llvm::APFloat(initialOvershardingFactor));
for (const std::pair<int, float> &p : overshardingBrackets) {
Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
Value inBracket = b.create<arith::CmpIOp>(
arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
- llvm::APFloat(p.second), b.getF32Type());
+ b.getF32Type(), llvm::APFloat(p.second));
scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor,
scalingFactor);
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index a75598afe8c72..d35f72e5a9e26 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -83,11 +83,11 @@ struct GpuAllReduceRewriter {
// Compute lane id (invocation id withing the subgroup).
Value subgroupMask =
- create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
+ create<arith::ConstantIntOp>(int32Type, kSubgroupSize - 1);
Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
Value isFirstLane =
create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
- create<arith::ConstantIntOp>(0, int32Type));
+ create<arith::ConstantIntOp>(int32Type, 0));
Value numThreadsWithSmallerSubgroupId =
create<arith::SubIOp>(invocationIdx, laneId);
@@ -282,7 +282,7 @@ struct GpuAllReduceRewriter {
/// The first lane returns the result, all others return values are undefined.
Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
AccumulatorFactory &accumFactory) {
- Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
+ Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
activeWidth, subgroupSize);
std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
@@ -296,7 +296,7 @@ struct GpuAllReduceRewriter {
// lane is within the active range. The accumulated value is available
// in the first lane.
for (int i = 1; i < kSubgroupSize; i <<= 1) {
- Value offset = create<arith::ConstantIntOp>(i, int32Type);
+ Value offset = create<arith::ConstantIntOp>(int32Type, i);
auto shuffleOp = create<gpu::ShuffleOp>(
shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
// Skip the accumulation if the shuffle op read from a lane outside
@@ -318,7 +318,7 @@ struct GpuAllReduceRewriter {
[&] {
Value value = operand;
for (int i = 1; i < kSubgroupSize; i <<= 1) {
- Value offset = create<arith::ConstantIntOp>(i, int32Type);
+ Value offset = create<arith::ConstantIntOp>(int32Type, i);
auto shuffleOp =
create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
gpu::ShuffleMode::XOR);
@@ -331,7 +331,7 @@ struct GpuAllReduceRewriter {
/// Returns value divided by the subgroup size (i.e. 32).
Value getDivideBySubgroupSize(Value value) {
- Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
+ Value subgroupSize = create<arith::ConstantIntOp>(int32Type, kSubgroupSize);
return create<arith::DivSIOp>(int32Type, value, subgroupSize);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 999359c7fa872..1419175304899 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -133,13 +133,13 @@ static Value getZero(OpBuilder &b, Location loc, Type elementType) {
assert(elementType.isIntOrIndexOrFloat() &&
"expected scalar type while computing zero value");
if (isa<IntegerType>(elementType))
- return b.create<arith::ConstantIntOp>(loc, 0, elementType);
+ return b.create<arith::ConstantIntOp>(loc, elementType, 0);
if (elementType.isIndex())
return b.create<arith::ConstantIndexOp>(loc, 0);
// Assume float.
auto floatType = cast<FloatType>(elementType);
return b.create<arith::ConstantFloatOp>(
- loc, APFloat::getZero(floatType.getFloatSemantics()), floatType);
+ loc, floatType, APFloat::getZero(floatType.getFloatSemantics()));
}
GenericOp
diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
index c2dbcde1aeba6..793db73575b4f 100644
--- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp
@@ -315,9 +315,9 @@ Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
auto inputType = input.getType();
auto storageType = quantizedType.getStorageType();
auto storageMinScalar = builder.create<arith::ConstantIntOp>(
- loc, quantizedType.getStorageTypeMin(), storageType);
+ loc, storageType, quantizedType.getStorageTypeMin());
auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
- loc, quantizedType.getStorageTypeMax(), storageType);
+ loc, storageType, quantizedType.getStorageTypeMax());
auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
inputType, inputShape);
auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
index ed73d81198f29..66f7bc27f82ff 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -141,7 +141,7 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes,
b.setInsertionPointToStart(innerLoop.getBody());
// Insert in-bound check
Value inbound =
- b.create<arith::ConstantIntOp>(op.getLoc(), 1, b.getIntegerType(1));
+ b.create<arith::ConstantIntOp>(op.getLoc(), b.getIntegerType(1), 1);
for (auto [outerUpperBound, outerIV, innerIV, innerStep] :
llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(),
innerLoop.getInductionVars(), innerLoop.getStep())) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index ebe718ae4fb61..29d6d2574a2be 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -240,7 +240,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
if (isa<IndexType>(step.getType())) {
one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
} else {
- one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
+ one = rewriter.create<arith::ConstantIntOp>(loc, step.getType(), 1);
}
Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please elaborate what "other constructors" you refer to in the commit message. LLVM doesn't insist on API stability, but it's unclear why this specific API is suddenly a problem after five years of existence.
@ftynse In all instances of generated code for the MLIR dialects's Ops (that is the TableGen using the .td files to create the .h.inc/.cpp.inc files), the desired result type is always specified before the value. Examples:
The constructors of most ops are generated like so. When I was trying to use |
Change made to standardize with all other
Op
build()
constructors.