Skip to content

Commit 3e78c65

Browse files
committed
[CIR] Streamline creation of mlir::IntegerAttrs using mlir::Builder
1 parent 9707b1d commit 3e78c65

File tree

10 files changed

+79
-89
lines changed

10 files changed

+79
-89
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -156,17 +156,20 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
156156
llvm_unreachable("Zero initializer for given type is NYI");
157157
}
158158

159+
using mlir::OpBuilder::getIntegerAttr;
160+
161+
mlir::IntegerAttr getIntegerAttr(const llvm::APInt &value) {
162+
mlir::Type intType = getIntegerType(value.getBitWidth());
163+
return getIntegerAttr(intType, value);
164+
}
165+
159166
cir::LoadOp createLoad(mlir::Location loc, mlir::Value ptr,
160167
bool isVolatile = false, bool isNontemporal = false,
161168
uint64_t alignment = 0) {
162-
mlir::IntegerAttr intAttr;
163-
if (alignment)
164-
intAttr = mlir::IntegerAttr::get(
165-
mlir::IntegerType::get(ptr.getContext(), 64), alignment);
166-
169+
mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
167170
return create<cir::LoadOp>(loc, ptr, /*isDeref=*/false, isVolatile,
168171
isNontemporal,
169-
/*alignment=*/intAttr,
172+
/*alignment=*/alignmentAttr,
170173
/*mem_order=*/
171174
cir::MemOrderAttr{},
172175
/*tbaa=*/cir::TBAAAttr{});
@@ -373,9 +376,8 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
373376
mlir::Type type, llvm::StringRef name,
374377
clang::CharUnits alignment,
375378
mlir::Value dynAllocSize) {
376-
auto alignmentIntAttr = getSizeFromCharUnits(getContext(), alignment);
377-
return createAlloca(loc, addrType, type, name, alignmentIntAttr,
378-
dynAllocSize);
379+
mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
380+
return createAlloca(loc, addrType, type, name, alignmentAttr, dynAllocSize);
379381
}
380382

381383
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType,
@@ -387,8 +389,8 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
387389
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType,
388390
mlir::Type type, llvm::StringRef name,
389391
clang::CharUnits alignment) {
390-
auto alignmentIntAttr = getSizeFromCharUnits(getContext(), alignment);
391-
return createAlloca(loc, addrType, type, name, alignmentIntAttr);
392+
mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
393+
return createAlloca(loc, addrType, type, name, alignmentAttr);
392394
}
393395

394396
mlir::Value createGetGlobal(mlir::Location loc, cir::GlobalOp global,
@@ -581,13 +583,28 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
581583
return OpBuilder::InsertPoint(block, block->begin());
582584
};
583585

584-
mlir::IntegerAttr getSizeFromCharUnits(mlir::MLIRContext *ctx,
585-
clang::CharUnits size) {
586-
// Note that mlir::IntegerType is used instead of cir::IntType here
587-
// because we don't need sign information for this to be useful, so keep
588-
// it simple.
589-
return mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64),
590-
size.getQuantity());
586+
//
587+
// Alignement and size helpers
588+
//
589+
590+
// Note that mlir::IntegerType is used instead of cir::IntType here because we
591+
// don't need sign information for these to be useful, so keep it simple.
592+
593+
// Fot 0 alignment, return an empty attribute.
594+
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment) {
595+
return getAlignmentAttr(alignment.getQuantity());
596+
}
597+
598+
mlir::IntegerAttr getAlignmentAttr(llvm::Align alignment) {
599+
return getAlignmentAttr(alignment.value());
600+
}
601+
602+
mlir::IntegerAttr getAlignmentAttr(int64_t alignment) {
603+
return alignment ? getI64IntegerAttr(alignment) : mlir::IntegerAttr();
604+
}
605+
606+
mlir::IntegerAttr getSizeFromCharUnits(clang::CharUnits size) {
607+
return getI64IntegerAttr(size.getQuantity());
591608
}
592609

593610
/// Create a do-while operation.
@@ -616,9 +633,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
616633
}
617634

618635
mlir::TypedAttr getConstPtrAttr(mlir::Type t, int64_t v) {
619-
auto val =
620-
mlir::IntegerAttr::get(mlir::IntegerType::get(t.getContext(), 64), v);
621-
return cir::ConstPtrAttr::get(t, val);
636+
return cir::ConstPtrAttr::get(t, getI64IntegerAttr(v));
622637
}
623638

624639
mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
151151
llvm::ArrayRef<int64_t> indices) {
152152
llvm::SmallVector<mlir::Attribute> attrs;
153153
for (auto ind : indices) {
154-
auto a =
155-
mlir::IntegerAttr::get(mlir::IntegerType::get(getContext(), 64), ind);
156-
attrs.push_back(a);
154+
attrs.push_back(getI64IntegerAttr(ind));
157155
}
158156

159157
mlir::ArrayAttr arAttr = mlir::ArrayAttr::get(getContext(), attrs);
@@ -938,13 +936,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
938936
clang::CharUnits align = clang::CharUnits::One(),
939937
bool isVolatile = false, bool isNontemporal = false,
940938
cir::MemOrderAttr order = {}) {
941-
llvm::MaybeAlign mayAlign = align.getAsAlign();
942-
mlir::IntegerAttr alignAttr;
943-
if (mayAlign) {
944-
uint64_t alignment = mayAlign ? mayAlign->value() : 0;
945-
alignAttr = mlir::IntegerAttr::get(
946-
mlir::IntegerType::get(dst.getContext(), 64), alignment);
947-
}
939+
mlir::IntegerAttr alignAttr = getAlignmentAttr(align);
948940
return CIRBaseBuilderTy::createStore(loc, val, dst, isVolatile,
949941
isNontemporal, alignAttr, order);
950942
}

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,10 +1706,9 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
17061706
E->getArg(1)->getExprLoc(), FD, 1);
17071707
uint64_t size =
17081708
E->getArg(2)->EvaluateKnownConstInt(getContext()).getZExtValue();
1709-
builder.create<cir::MemCpyInlineOp>(
1710-
getLoc(E->getSourceRange()), dest.getPointer(), src.getPointer(),
1711-
mlir::IntegerAttr::get(mlir::IntegerType::get(builder.getContext(), 64),
1712-
size));
1709+
builder.create<cir::MemCpyInlineOp>(getLoc(E->getSourceRange()),
1710+
dest.getPointer(), src.getPointer(),
1711+
builder.getI64IntegerAttr(size));
17131712
// __builtin_memcpy_inline has no return value
17141713
return RValue::get(nullptr);
17151714
}
@@ -1788,10 +1787,8 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
17881787
E->getArg(2)->EvaluateKnownConstInt(getContext()).getZExtValue();
17891788
emitNonNullArgCheck(RValue::get(Dest.getPointer()), E->getArg(0)->getType(),
17901789
E->getArg(0)->getExprLoc(), FD, 0);
1791-
builder.createMemSetInline(
1792-
getLoc(E->getSourceRange()), Dest.getPointer(), ByteVal,
1793-
mlir::IntegerAttr::get(mlir::IntegerType::get(builder.getContext(), 64),
1794-
size));
1790+
builder.createMemSetInline(getLoc(E->getSourceRange()), Dest.getPointer(),
1791+
ByteVal, builder.getI64IntegerAttr(size));
17951792
// __builtin_memset_inline has no return value
17961793
return RValue::get(nullptr);
17971794
}

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,12 +1269,11 @@ class CIRGenFunction : public CIRGenTypeCache {
12691269
{
12701270
mlir::OpBuilder::InsertionGuard guard(builder);
12711271
builder.restoreInsertionPoint(OutermostConditional->getInsertPoint());
1272-
builder.createStore(
1273-
value.getLoc(), value, addr,
1274-
/*isVolatile=*/false, /*isNontemporal=*/false,
1275-
mlir::IntegerAttr::get(
1276-
mlir::IntegerType::get(value.getContext(), 64),
1277-
(uint64_t)addr.getAlignment().getAsAlign().value()));
1272+
mlir::IntegerAttr alignmentAttr =
1273+
builder.getAlignmentAttr(addr.getAlignment());
1274+
builder.createStore(value.getLoc(), value, addr,
1275+
/*isVolatile=*/false, /*isNontemporal=*/false,
1276+
alignmentAttr);
12781277
}
12791278
}
12801279

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3342,7 +3342,7 @@ void CIRGenModule::emitDeferred(unsigned recursionLimit) {
33423342
}
33433343

33443344
mlir::IntegerAttr CIRGenModule::getSize(CharUnits size) {
3345-
return builder.getSizeFromCharUnits(&getMLIRContext(), size);
3345+
return builder.getSizeFromCharUnits(size);
33463346
}
33473347

33483348
mlir::Operation *
@@ -4214,10 +4214,7 @@ mlir::ArrayAttr CIRGenModule::emitAnnotationArgs(const AnnotateAttr *attr) {
42144214
// Handle case which can be evaluated to some numbers, not only literals
42154215
const auto &ap = ce.getAPValueResult();
42164216
if (ap.isInt()) {
4217-
args.push_back(mlir::IntegerAttr::get(
4218-
mlir::IntegerType::get(&getMLIRContext(),
4219-
ap.getInt().getBitWidth()),
4220-
ap.getInt()));
4217+
args.push_back(builder.getIntegerAttr(ap.getInt()));
42214218
} else {
42224219
llvm_unreachable("NYI like float, fixed-point, array...");
42234220
}

clang/lib/CIR/CodeGen/ConstantInitBuilder.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ class ConstantAggregateBuilderBase {
193193

194194
/// Add a pointer of a specific type.
195195
void addPointer(cir::PointerType ptrTy, uint64_t value) {
196-
auto val = mlir::IntegerAttr::get(
197-
mlir::IntegerType::get(ptrTy.getContext(), 64), value);
196+
mlir::IntegerAttr val = Builder.builder.getI64IntegerAttr(value);
198197
add(cir::ConstPtrAttr::get(ptrTy, val));
199198
}
200199

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,7 @@ LogicalResult OptInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError,
186186
static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) {
187187

188188
if (parser.parseOptionalKeyword("null").succeeded()) {
189-
value = mlir::IntegerAttr::get(
190-
mlir::IntegerType::get(parser.getContext(), 64), 0);
189+
value = parser.getBuilder().getI64IntegerAttr(0);
191190
return success();
192191
}
193192

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -960,10 +960,9 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
960960
// Rewrite the sub in front of extensions/trunc
961961
if (rewriteSub) {
962962
index = rewriter.create<mlir::LLVM::SubOp>(
963-
index.getLoc(), index.getType(),
964-
rewriter.create<mlir::LLVM::ConstantOp>(
965-
index.getLoc(), index.getType(),
966-
mlir::IntegerAttr::get(index.getType(), 0)),
963+
index.getLoc(),
964+
rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(),
965+
index.getType(), 0),
967966
index);
968967
rewriter.eraseOp(sub);
969968
}
@@ -1320,8 +1319,7 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
13201319
return mlir::success();
13211320
}
13221321
case cir::CastKind::ptr_to_bool: {
1323-
auto zero =
1324-
mlir::IntegerAttr::get(mlir::IntegerType::get(getContext(), 64), 0);
1322+
auto zero = rewriter.getI64IntegerAttr(0);
13251323
auto null = rewriter.create<cir::ConstantOp>(
13261324
src.getLoc(), cir::ConstPtrAttr::get(castOp.getSrc().getType(), zero));
13271325
rewriter.replaceOpWithNewOp<cir::CmpOp>(
@@ -2035,8 +2033,7 @@ mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
20352033
uint64_t numElements =
20362034
mlir::cast<cir::VectorType>(op.getVec().getType()).getSize();
20372035
mlir::Value maskValue = rewriter.create<mlir::LLVM::ConstantOp>(
2038-
loc, llvmIndexType,
2039-
mlir::IntegerAttr::get(llvmIndexType, numElements - 1));
2036+
loc, llvmIndexType, numElements - 1);
20402037
mlir::Value maskVector =
20412038
rewriter.create<mlir::LLVM::UndefOp>(loc, llvmIndexVecType);
20422039
for (uint64_t i = 0; i < numElements; ++i) {
@@ -2607,16 +2604,14 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
26072604
switch (op.getKind()) {
26082605
case cir::UnaryOpKind::Inc: {
26092606
assert(!IsVector && "++ not allowed on vector types");
2610-
auto One = rewriter.create<mlir::LLVM::ConstantOp>(
2611-
loc, llvmType, mlir::IntegerAttr::get(llvmType, 1));
2607+
auto One = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
26122608
rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(
26132609
op, llvmType, adaptor.getInput(), One, overflowFlags);
26142610
return mlir::success();
26152611
}
26162612
case cir::UnaryOpKind::Dec: {
26172613
assert(!IsVector && "-- not allowed on vector types");
2618-
auto One = rewriter.create<mlir::LLVM::ConstantOp>(
2619-
loc, llvmType, mlir::IntegerAttr::get(llvmType, 1));
2614+
auto One = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
26202615
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
26212616
op, llvmType, adaptor.getInput(), One, overflowFlags);
26222617
return mlir::success();
@@ -2630,10 +2625,9 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
26302625
if (IsVector)
26312626
Zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType);
26322627
else
2633-
Zero = rewriter.create<mlir::LLVM::ConstantOp>(
2634-
loc, llvmType, mlir::IntegerAttr::get(llvmType, 0));
2628+
Zero = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 0);
26352629
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
2636-
op, llvmType, Zero, adaptor.getInput(), overflowFlags);
2630+
op, Zero, adaptor.getInput(), overflowFlags);
26372631
return mlir::success();
26382632
}
26392633
case cir::UnaryOpKind::Not: {
@@ -2644,8 +2638,8 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
26442638
// done. It requires a series of insertelement ops.
26452639
mlir::Type llvmElementType =
26462640
getTypeConverter()->convertType(elementType);
2647-
auto MinusOneInt = rewriter.create<mlir::LLVM::ConstantOp>(
2648-
loc, llvmElementType, mlir::IntegerAttr::get(llvmElementType, -1));
2641+
auto MinusOneInt =
2642+
rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmElementType, -1);
26492643
minusOne = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmType);
26502644
auto NumElements = mlir::dyn_cast<cir::VectorType>(type).getSize();
26512645
for (uint64_t i = 0; i < NumElements; ++i) {
@@ -2655,11 +2649,10 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
26552649
loc, minusOne, MinusOneInt, indexValue);
26562650
}
26572651
} else {
2658-
minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
2659-
loc, llvmType, mlir::IntegerAttr::get(llvmType, -1));
2652+
minusOne = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, -1);
26602653
}
2661-
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(
2662-
op, llvmType, adaptor.getInput(), minusOne);
2654+
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(),
2655+
minusOne);
26632656
return mlir::success();
26642657
}
26652658
}
@@ -2707,9 +2700,8 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
27072700
case cir::UnaryOpKind::Not:
27082701
assert(!IsVector && "NYI: op! on vector mask");
27092702
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(
2710-
op, llvmType, adaptor.getInput(),
2711-
rewriter.create<mlir::LLVM::ConstantOp>(
2712-
loc, llvmType, mlir::IntegerAttr::get(llvmType, 1)));
2703+
op, adaptor.getInput(),
2704+
rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1));
27132705
return mlir::success();
27142706
default:
27152707
return op.emitError()
@@ -3671,14 +3663,14 @@ mlir::LogicalResult CIRToLLVMPtrDiffOpLowering::matchAndRewrite(
36713663
auto resultVal = diff.getResult();
36723664
if (typeSize != 1) {
36733665
auto typeSizeVal = rewriter.create<mlir::LLVM::ConstantOp>(
3674-
op.getLoc(), llvmDstTy, mlir::IntegerAttr::get(llvmDstTy, typeSize));
3666+
op.getLoc(), llvmDstTy, typeSize);
36753667

36763668
if (dstTy.isUnsigned())
3677-
resultVal = rewriter.create<mlir::LLVM::UDivOp>(op.getLoc(), llvmDstTy,
3678-
diff, typeSizeVal);
3669+
resultVal =
3670+
rewriter.create<mlir::LLVM::UDivOp>(op.getLoc(), diff, typeSizeVal);
36793671
else
3680-
resultVal = rewriter.create<mlir::LLVM::SDivOp>(op.getLoc(), llvmDstTy,
3681-
diff, typeSizeVal);
3672+
resultVal =
3673+
rewriter.create<mlir::LLVM::SDivOp>(op.getLoc(), diff, typeSizeVal);
36823674
}
36833675
rewriter.replaceOp(op, resultVal);
36843676
return mlir::success();

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ mlir::Value SCFLoop::plusConstant(mlir::Value V, mlir::Location loc,
232232
int addend) {
233233
auto type = V.getType();
234234
auto c1 = rewriter->create<mlir::arith::ConstantOp>(
235-
loc, type, mlir::IntegerAttr::get(type, addend));
235+
loc, mlir::IntegerAttr::get(type, addend));
236236
return rewriter->create<mlir::arith::AddIOp>(loc, V, c1);
237237
}
238238

@@ -285,7 +285,7 @@ void SCFLoop::transferToSCFForOp() {
285285
auto loc = forOp.getLoc();
286286
auto type = lb.getType();
287287
auto step = rewriter->create<mlir::arith::ConstantOp>(
288-
loc, type, mlir::IntegerAttr::get(type, getStep()));
288+
loc, mlir::IntegerAttr::get(type, getStep()));
289289
auto scfForOp = rewriter->create<mlir::scf::ForOp>(loc, lb, ub, step);
290290
SmallVector<mlir::Value> bbArg;
291291
rewriter->eraseOp(&scfForOp.getBody()->back());

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -571,13 +571,13 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern<cir::UnaryOp> {
571571
switch (op.getKind()) {
572572
case cir::UnaryOpKind::Inc: {
573573
auto One = rewriter.create<mlir::arith::ConstantOp>(
574-
op.getLoc(), type, mlir::IntegerAttr::get(type, 1));
574+
op.getLoc(), mlir::IntegerAttr::get(type, 1));
575575
rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(op, type, input, One);
576576
break;
577577
}
578578
case cir::UnaryOpKind::Dec: {
579579
auto One = rewriter.create<mlir::arith::ConstantOp>(
580-
op.getLoc(), type, mlir::IntegerAttr::get(type, 1));
580+
op.getLoc(), mlir::IntegerAttr::get(type, 1));
581581
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(op, type, input, One);
582582
break;
583583
}
@@ -587,13 +587,13 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern<cir::UnaryOp> {
587587
}
588588
case cir::UnaryOpKind::Minus: {
589589
auto Zero = rewriter.create<mlir::arith::ConstantOp>(
590-
op.getLoc(), type, mlir::IntegerAttr::get(type, 0));
590+
op.getLoc(), mlir::IntegerAttr::get(type, 0));
591591
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(op, type, Zero, input);
592592
break;
593593
}
594594
case cir::UnaryOpKind::Not: {
595595
auto MinusOne = rewriter.create<mlir::arith::ConstantOp>(
596-
op.getLoc(), type, mlir::IntegerAttr::get(type, -1));
596+
op.getLoc(), mlir::IntegerAttr::get(type, -1));
597597
rewriter.replaceOpWithNewOp<mlir::arith::XOrIOp>(op, type, MinusOne,
598598
input);
599599
break;

0 commit comments

Comments
 (0)