Skip to content

Commit c842705

Browse files
authored
[CIR] Streamline creation of mlir::IntegerAttrs using mlir::Builder (#141830)
- Uses getI<bitwidth>IntegerAttr builder method instead of explicit attribute and its type creation. - Adds few helper functions `getAlignmentAttr` to build alignment representing mlir::IntegerAttr. - Removes duplicit type parameters, that are inferred from mlir::IntegerAttr. This mirrors incubator changes from llvm/clangir#1645 (comment)
1 parent baeec97 commit c842705

File tree

6 files changed

+53
-61
lines changed

6 files changed

+53
-61
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/Support/ErrorHandling.h"
1919

2020
#include "mlir/IR/Builders.h"
21+
#include "mlir/IR/BuiltinAttributes.h"
2122
#include "mlir/IR/BuiltinTypes.h"
2223
#include "mlir/IR/Location.h"
2324
#include "mlir/IR/Types.h"
@@ -167,9 +168,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
167168
}
168169

169170
mlir::TypedAttr getConstPtrAttr(mlir::Type type, int64_t value) {
170-
auto valueAttr = mlir::IntegerAttr::get(
171-
mlir::IntegerType::get(type.getContext(), 64), value);
172-
return cir::ConstPtrAttr::get(type, valueAttr);
171+
return cir::ConstPtrAttr::get(type, getI64IntegerAttr(value));
173172
}
174173

175174
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType,
@@ -197,14 +196,9 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
197196

198197
mlir::Value createDummyValue(mlir::Location loc, mlir::Type type,
199198
clang::CharUnits alignment) {
200-
auto addr = createAlloca(loc, getPointerTo(type), type, {},
201-
getSizeFromCharUnits(getContext(), alignment));
202-
mlir::IntegerAttr alignAttr;
203-
uint64_t align = alignment.getQuantity();
204-
if (align)
205-
alignAttr = getI64IntegerAttr(align);
206-
207-
return create<cir::LoadOp>(loc, addr, /*isDeref=*/false, alignAttr);
199+
mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
200+
auto addr = createAlloca(loc, getPointerTo(type), type, {}, alignmentAttr);
201+
return create<cir::LoadOp>(loc, addr, /*isDeref=*/false, alignmentAttr);
208202
}
209203

210204
cir::PtrStrideOp createPtrStride(mlir::Location loc, mlir::Value base,
@@ -428,13 +422,29 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
428422
return OpBuilder::InsertPoint(block, block->begin());
429423
};
430424

431-
mlir::IntegerAttr getSizeFromCharUnits(mlir::MLIRContext *ctx,
432-
clang::CharUnits size) {
433-
// Note that mlir::IntegerType is used instead of cir::IntType here
434-
// because we don't need sign information for this to be useful, so keep
435-
// it simple.
436-
return mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64),
437-
size.getQuantity());
425+
//
426+
// Alignment and size helpers
427+
//
428+
429+
// Note that mlir::IntegerType is used instead of cir::IntType here because we
430+
// don't need sign information for these to be useful, so keep it simple.
431+
432+
// For 0 alignment, any overload of `getAlignmentAttr` returns an empty
433+
// attribute.
434+
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment) {
435+
return getAlignmentAttr(alignment.getQuantity());
436+
}
437+
438+
mlir::IntegerAttr getAlignmentAttr(llvm::Align alignment) {
439+
return getAlignmentAttr(alignment.value());
440+
}
441+
442+
mlir::IntegerAttr getAlignmentAttr(int64_t alignment) {
443+
return alignment ? getI64IntegerAttr(alignment) : mlir::IntegerAttr();
444+
}
445+
446+
mlir::IntegerAttr getSizeFromCharUnits(clang::CharUnits size) {
447+
return getI64IntegerAttr(size.getQuantity());
438448
}
439449

440450
/// Create a loop condition.

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -282,22 +282,15 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
282282

283283
cir::LoadOp createLoad(mlir::Location loc, Address addr,
284284
bool isVolatile = false) {
285-
mlir::IntegerAttr align;
286-
uint64_t alignment = addr.getAlignment().getQuantity();
287-
if (alignment)
288-
align = getI64IntegerAttr(alignment);
285+
mlir::IntegerAttr align = getAlignmentAttr(addr.getAlignment());
289286
return create<cir::LoadOp>(loc, addr.getPointer(), /*isDeref=*/false,
290287
align);
291288
}
292289

293290
cir::StoreOp createStore(mlir::Location loc, mlir::Value val, Address dst,
294-
::mlir::IntegerAttr align = {}) {
295-
if (!align) {
296-
uint64_t alignment = dst.getAlignment().getQuantity();
297-
if (alignment)
298-
align = mlir::IntegerAttr::get(mlir::IntegerType::get(getContext(), 64),
299-
alignment);
300-
}
291+
mlir::IntegerAttr align = {}) {
292+
if (!align)
293+
align = getAlignmentAttr(dst.getAlignment());
301294
return CIRBaseBuilderTy::createStore(loc, val, dst.getPointer(), align);
302295
}
303296

clang/lib/CIR/CodeGen/CIRGenModule.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ class CIRGenModule : public CIRGenTypeCache {
212212
const clang::FunctionDecl *funcDecl);
213213

214214
mlir::IntegerAttr getSize(CharUnits size) {
215-
return builder.getSizeFromCharUnits(&getMLIRContext(), size);
215+
return builder.getSizeFromCharUnits(size);
216216
}
217217

218218
const llvm::Triple &getTriple() const { return target.getTriple(); }

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {
6464
static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) {
6565

6666
if (parser.parseOptionalKeyword("null").succeeded()) {
67-
value = mlir::IntegerAttr::get(
68-
mlir::IntegerType::get(parser.getContext(), 64), 0);
67+
value = parser.getBuilder().getI64IntegerAttr(0);
6968
return success();
7069
}
7170

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,7 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
417417
case cir::CastKind::int_to_bool: {
418418
mlir::Value llvmSrcVal = adaptor.getOperands().front();
419419
mlir::Value zeroInt = rewriter.create<mlir::LLVM::ConstantOp>(
420-
castOp.getLoc(), llvmSrcVal.getType(),
421-
mlir::IntegerAttr::get(llvmSrcVal.getType(), 0));
420+
castOp.getLoc(), llvmSrcVal.getType(), 0);
422421
rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
423422
castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroInt);
424423
break;
@@ -630,9 +629,8 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
630629
if (rewriteSub) {
631630
index = rewriter.create<mlir::LLVM::SubOp>(
632631
index.getLoc(), index.getType(),
633-
rewriter.create<mlir::LLVM::ConstantOp>(
634-
index.getLoc(), index.getType(),
635-
mlir::IntegerAttr::get(index.getType(), 0)),
632+
rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(),
633+
index.getType(), 0),
636634
index);
637635
rewriter.eraseOp(sub);
638636
}
@@ -648,8 +646,7 @@ mlir::LogicalResult CIRToLLVMAllocaOpLowering::matchAndRewrite(
648646
mlir::ConversionPatternRewriter &rewriter) const {
649647
assert(!cir::MissingFeatures::opAllocaDynAllocSize());
650648
mlir::Value size = rewriter.create<mlir::LLVM::ConstantOp>(
651-
op.getLoc(), typeConverter->convertType(rewriter.getIndexType()),
652-
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
649+
op.getLoc(), typeConverter->convertType(rewriter.getIndexType()), 1);
653650
mlir::Type elementTy =
654651
convertTypeForMemory(*getTypeConverter(), dataLayout, op.getAllocaType());
655652
mlir::Type resultTy = convertTypeForMemory(*getTypeConverter(), dataLayout,
@@ -1111,18 +1108,16 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
11111108
switch (op.getKind()) {
11121109
case cir::UnaryOpKind::Inc: {
11131110
assert(!isVector && "++ not allowed on vector types");
1114-
mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>(
1115-
loc, llvmType, mlir::IntegerAttr::get(llvmType, 1));
1111+
auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
11161112
rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(
11171113
op, llvmType, adaptor.getInput(), one, maybeNSW);
11181114
return mlir::success();
11191115
}
11201116
case cir::UnaryOpKind::Dec: {
11211117
assert(!isVector && "-- not allowed on vector types");
1122-
mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>(
1123-
loc, llvmType, mlir::IntegerAttr::get(llvmType, 1));
1124-
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
1125-
op, llvmType, adaptor.getInput(), one, maybeNSW);
1118+
auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
1119+
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, adaptor.getInput(),
1120+
one, maybeNSW);
11261121
return mlir::success();
11271122
}
11281123
case cir::UnaryOpKind::Plus:
@@ -1133,10 +1128,9 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
11331128
if (isVector)
11341129
zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType);
11351130
else
1136-
zero = rewriter.create<mlir::LLVM::ConstantOp>(
1137-
loc, llvmType, mlir::IntegerAttr::get(llvmType, 0));
1131+
zero = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 0);
11381132
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
1139-
op, llvmType, zero, adaptor.getInput(), maybeNSW);
1133+
op, zero, adaptor.getInput(), maybeNSW);
11401134
return mlir::success();
11411135
}
11421136
case cir::UnaryOpKind::Not: {
@@ -1150,11 +1144,10 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
11501144
minusOne =
11511145
rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, denseVec);
11521146
} else {
1153-
minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
1154-
loc, llvmType, mlir::IntegerAttr::get(llvmType, -1));
1147+
minusOne = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, -1);
11551148
}
1156-
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(
1157-
op, llvmType, adaptor.getInput(), minusOne);
1149+
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(),
1150+
minusOne);
11581151
return mlir::success();
11591152
}
11601153
}
@@ -1206,10 +1199,9 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
12061199
return op.emitError() << "Unsupported unary operation on boolean type";
12071200
case cir::UnaryOpKind::Not: {
12081201
assert(!isVector && "NYI: op! on vector mask");
1209-
mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>(
1210-
loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
1211-
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, llvmType,
1212-
adaptor.getInput(), one);
1202+
auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
1203+
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(),
1204+
one);
12131205
return mlir::success();
12141206
}
12151207
}

clang/unittests/CIR/PointerLikeTest.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,10 @@ class CIROpenACCPointerLikeTest : public ::testing::Test {
4747
llvm::StringMap<unsigned> recordNames;
4848

4949
mlir::IntegerAttr getAlignOne(mlir::MLIRContext *ctx) {
50-
// Note that mlir::IntegerType is used instead of cir::IntType here
51-
// because we don't need sign information for this to be useful, so keep
52-
// it simple.
50+
// Note that mlir::IntegerType is used instead of cir::IntType here because
51+
// we don't need sign information for this to be useful, so keep it simple.
5352
clang::CharUnits align = clang::CharUnits::One();
54-
return mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64),
55-
align.getQuantity());
53+
return b.getI64IntegerAttr(align.getQuantity());
5654
}
5755

5856
mlir::StringAttr getUniqueRecordName(const std::string &baseName) {

0 commit comments

Comments
 (0)