Skip to content

[CIR] Streamline creation of mlir::IntegerAttrs using mlir::Builder #1645

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 36 additions & 21 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,17 +156,20 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
llvm_unreachable("Zero initializer for given type is NYI");
}

using mlir::OpBuilder::getIntegerAttr;

mlir::IntegerAttr getIntegerAttr(const llvm::APInt &value) {
mlir::Type intType = getIntegerType(value.getBitWidth());
return getIntegerAttr(intType, value);
}

cir::LoadOp createLoad(mlir::Location loc, mlir::Value ptr,
bool isVolatile = false, bool isNontemporal = false,
uint64_t alignment = 0) {
mlir::IntegerAttr intAttr;
if (alignment)
intAttr = mlir::IntegerAttr::get(
mlir::IntegerType::get(ptr.getContext(), 64), alignment);

mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
return create<cir::LoadOp>(loc, ptr, /*isDeref=*/false, isVolatile,
isNontemporal,
/*alignment=*/intAttr,
/*alignment=*/alignmentAttr,
/*mem_order=*/
cir::MemOrderAttr{},
/*tbaa=*/cir::TBAAAttr{});
Expand Down Expand Up @@ -373,9 +376,8 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
mlir::Type type, llvm::StringRef name,
clang::CharUnits alignment,
mlir::Value dynAllocSize) {
auto alignmentIntAttr = getSizeFromCharUnits(getContext(), alignment);
return createAlloca(loc, addrType, type, name, alignmentIntAttr,
dynAllocSize);
mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
return createAlloca(loc, addrType, type, name, alignmentAttr, dynAllocSize);
}

mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType,
Expand All @@ -387,8 +389,8 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType,
mlir::Type type, llvm::StringRef name,
clang::CharUnits alignment) {
auto alignmentIntAttr = getSizeFromCharUnits(getContext(), alignment);
return createAlloca(loc, addrType, type, name, alignmentIntAttr);
mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
return createAlloca(loc, addrType, type, name, alignmentAttr);
}

mlir::Value createGetGlobal(mlir::Location loc, cir::GlobalOp global,
Expand Down Expand Up @@ -581,13 +583,28 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
return OpBuilder::InsertPoint(block, block->begin());
};

mlir::IntegerAttr getSizeFromCharUnits(mlir::MLIRContext *ctx,
clang::CharUnits size) {
// Note that mlir::IntegerType is used instead of cir::IntType here
// because we don't need sign information for this to be useful, so keep
// it simple.
return mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64),
size.getQuantity());
//
// Alignement and size helpers
//

// Note that mlir::IntegerType is used instead of cir::IntType here because we
// don't need sign information for these to be useful, so keep it simple.

// Fot 0 alignment, return an empty attribute.
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment) {
return getAlignmentAttr(alignment.getQuantity());
}

mlir::IntegerAttr getAlignmentAttr(llvm::Align alignment) {
return getAlignmentAttr(alignment.value());
}

mlir::IntegerAttr getAlignmentAttr(int64_t alignment) {
return alignment ? getI64IntegerAttr(alignment) : mlir::IntegerAttr();
}

mlir::IntegerAttr getSizeFromCharUnits(clang::CharUnits size) {
return getI64IntegerAttr(size.getQuantity());
}

/// Create a do-while operation.
Expand Down Expand Up @@ -616,9 +633,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
}

mlir::TypedAttr getConstPtrAttr(mlir::Type t, int64_t v) {
auto val =
mlir::IntegerAttr::get(mlir::IntegerType::get(t.getContext(), 64), v);
return cir::ConstPtrAttr::get(t, val);
return cir::ConstPtrAttr::get(t, getI64IntegerAttr(v));
}

mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {
Expand Down
12 changes: 2 additions & 10 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
llvm::ArrayRef<int64_t> indices) {
llvm::SmallVector<mlir::Attribute> attrs;
for (auto ind : indices) {
auto a =
mlir::IntegerAttr::get(mlir::IntegerType::get(getContext(), 64), ind);
attrs.push_back(a);
attrs.push_back(getI64IntegerAttr(ind));
}

mlir::ArrayAttr arAttr = mlir::ArrayAttr::get(getContext(), attrs);
Expand Down Expand Up @@ -938,13 +936,7 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
clang::CharUnits align = clang::CharUnits::One(),
bool isVolatile = false, bool isNontemporal = false,
cir::MemOrderAttr order = {}) {
llvm::MaybeAlign mayAlign = align.getAsAlign();
mlir::IntegerAttr alignAttr;
if (mayAlign) {
uint64_t alignment = mayAlign ? mayAlign->value() : 0;
alignAttr = mlir::IntegerAttr::get(
mlir::IntegerType::get(dst.getContext(), 64), alignment);
}
mlir::IntegerAttr alignAttr = getAlignmentAttr(align);
return CIRBaseBuilderTy::createStore(loc, val, dst, isVolatile,
isNontemporal, alignAttr, order);
}
Expand Down
13 changes: 5 additions & 8 deletions clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1706,10 +1706,9 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
E->getArg(1)->getExprLoc(), FD, 1);
uint64_t size =
E->getArg(2)->EvaluateKnownConstInt(getContext()).getZExtValue();
builder.create<cir::MemCpyInlineOp>(
getLoc(E->getSourceRange()), dest.getPointer(), src.getPointer(),
mlir::IntegerAttr::get(mlir::IntegerType::get(builder.getContext(), 64),
size));
builder.create<cir::MemCpyInlineOp>(getLoc(E->getSourceRange()),
dest.getPointer(), src.getPointer(),
builder.getI64IntegerAttr(size));
// __builtin_memcpy_inline has no return value
return RValue::get(nullptr);
}
Expand Down Expand Up @@ -1788,10 +1787,8 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
E->getArg(2)->EvaluateKnownConstInt(getContext()).getZExtValue();
emitNonNullArgCheck(RValue::get(Dest.getPointer()), E->getArg(0)->getType(),
E->getArg(0)->getExprLoc(), FD, 0);
builder.createMemSetInline(
getLoc(E->getSourceRange()), Dest.getPointer(), ByteVal,
mlir::IntegerAttr::get(mlir::IntegerType::get(builder.getContext(), 64),
size));
builder.createMemSetInline(getLoc(E->getSourceRange()), Dest.getPointer(),
ByteVal, builder.getI64IntegerAttr(size));
// __builtin_memset_inline has no return value
return RValue::get(nullptr);
}
Expand Down
11 changes: 5 additions & 6 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1269,12 +1269,11 @@ class CIRGenFunction : public CIRGenTypeCache {
{
mlir::OpBuilder::InsertionGuard guard(builder);
builder.restoreInsertionPoint(OutermostConditional->getInsertPoint());
builder.createStore(
value.getLoc(), value, addr,
/*isVolatile=*/false, /*isNontemporal=*/false,
mlir::IntegerAttr::get(
mlir::IntegerType::get(value.getContext(), 64),
(uint64_t)addr.getAlignment().getAsAlign().value()));
mlir::IntegerAttr alignmentAttr =
builder.getAlignmentAttr(addr.getAlignment());
builder.createStore(value.getLoc(), value, addr,
/*isVolatile=*/false, /*isNontemporal=*/false,
alignmentAttr);
}
}

Expand Down
7 changes: 2 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3342,7 +3342,7 @@ void CIRGenModule::emitDeferred(unsigned recursionLimit) {
}

mlir::IntegerAttr CIRGenModule::getSize(CharUnits size) {
return builder.getSizeFromCharUnits(&getMLIRContext(), size);
return builder.getSizeFromCharUnits(size);
}

mlir::Operation *
Expand Down Expand Up @@ -4214,10 +4214,7 @@ mlir::ArrayAttr CIRGenModule::emitAnnotationArgs(const AnnotateAttr *attr) {
// Handle case which can be evaluated to some numbers, not only literals
const auto &ap = ce.getAPValueResult();
if (ap.isInt()) {
args.push_back(mlir::IntegerAttr::get(
mlir::IntegerType::get(&getMLIRContext(),
ap.getInt().getBitWidth()),
ap.getInt()));
args.push_back(builder.getIntegerAttr(ap.getInt()));
} else {
llvm_unreachable("NYI like float, fixed-point, array...");
}
Expand Down
3 changes: 1 addition & 2 deletions clang/lib/CIR/CodeGen/ConstantInitBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ class ConstantAggregateBuilderBase {

/// Add a pointer of a specific type.
void addPointer(cir::PointerType ptrTy, uint64_t value) {
auto val = mlir::IntegerAttr::get(
mlir::IntegerType::get(ptrTy.getContext(), 64), value);
mlir::IntegerAttr val = Builder.builder.getI64IntegerAttr(value);
add(cir::ConstPtrAttr::get(ptrTy, val));
}

Expand Down
3 changes: 1 addition & 2 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,7 @@ LogicalResult OptInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError,
static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) {

if (parser.parseOptionalKeyword("null").succeeded()) {
value = mlir::IntegerAttr::get(
mlir::IntegerType::get(parser.getContext(), 64), 0);
value = parser.getBuilder().getI64IntegerAttr(0);
return success();
}

Expand Down
50 changes: 21 additions & 29 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -960,10 +960,9 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
// Rewrite the sub in front of extensions/trunc
if (rewriteSub) {
index = rewriter.create<mlir::LLVM::SubOp>(
index.getLoc(), index.getType(),
rewriter.create<mlir::LLVM::ConstantOp>(
index.getLoc(), index.getType(),
mlir::IntegerAttr::get(index.getType(), 0)),
index.getLoc(),
rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(),
index.getType(), 0),
index);
rewriter.eraseOp(sub);
}
Expand Down Expand Up @@ -1320,8 +1319,7 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
return mlir::success();
}
case cir::CastKind::ptr_to_bool: {
auto zero =
mlir::IntegerAttr::get(mlir::IntegerType::get(getContext(), 64), 0);
auto zero = rewriter.getI64IntegerAttr(0);
auto null = rewriter.create<cir::ConstantOp>(
src.getLoc(), cir::ConstPtrAttr::get(castOp.getSrc().getType(), zero));
rewriter.replaceOpWithNewOp<cir::CmpOp>(
Expand Down Expand Up @@ -2035,8 +2033,7 @@ mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
uint64_t numElements =
mlir::cast<cir::VectorType>(op.getVec().getType()).getSize();
mlir::Value maskValue = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmIndexType,
mlir::IntegerAttr::get(llvmIndexType, numElements - 1));
loc, llvmIndexType, numElements - 1);
mlir::Value maskVector =
rewriter.create<mlir::LLVM::UndefOp>(loc, llvmIndexVecType);
for (uint64_t i = 0; i < numElements; ++i) {
Expand Down Expand Up @@ -2607,16 +2604,14 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
switch (op.getKind()) {
case cir::UnaryOpKind::Inc: {
assert(!IsVector && "++ not allowed on vector types");
auto One = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmType, mlir::IntegerAttr::get(llvmType, 1));
auto One = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(
op, llvmType, adaptor.getInput(), One, overflowFlags);
return mlir::success();
}
case cir::UnaryOpKind::Dec: {
assert(!IsVector && "-- not allowed on vector types");
auto One = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmType, mlir::IntegerAttr::get(llvmType, 1));
auto One = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
op, llvmType, adaptor.getInput(), One, overflowFlags);
return mlir::success();
Expand All @@ -2630,10 +2625,9 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
if (IsVector)
Zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType);
else
Zero = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmType, mlir::IntegerAttr::get(llvmType, 0));
Zero = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 0);
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
op, llvmType, Zero, adaptor.getInput(), overflowFlags);
op, Zero, adaptor.getInput(), overflowFlags);
return mlir::success();
}
case cir::UnaryOpKind::Not: {
Expand All @@ -2644,8 +2638,8 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
// done. It requires a series of insertelement ops.
mlir::Type llvmElementType =
getTypeConverter()->convertType(elementType);
auto MinusOneInt = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmElementType, mlir::IntegerAttr::get(llvmElementType, -1));
auto MinusOneInt =
rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmElementType, -1);
minusOne = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmType);
auto NumElements = mlir::dyn_cast<cir::VectorType>(type).getSize();
for (uint64_t i = 0; i < NumElements; ++i) {
Expand All @@ -2655,11 +2649,10 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
loc, minusOne, MinusOneInt, indexValue);
}
} else {
minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmType, mlir::IntegerAttr::get(llvmType, -1));
minusOne = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, -1);
}
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(
op, llvmType, adaptor.getInput(), minusOne);
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(),
minusOne);
return mlir::success();
}
}
Expand Down Expand Up @@ -2707,9 +2700,8 @@ mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
case cir::UnaryOpKind::Not:
assert(!IsVector && "NYI: op! on vector mask");
rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(
op, llvmType, adaptor.getInput(),
rewriter.create<mlir::LLVM::ConstantOp>(
loc, llvmType, mlir::IntegerAttr::get(llvmType, 1)));
op, adaptor.getInput(),
rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1));
return mlir::success();
default:
return op.emitError()
Expand Down Expand Up @@ -3671,14 +3663,14 @@ mlir::LogicalResult CIRToLLVMPtrDiffOpLowering::matchAndRewrite(
auto resultVal = diff.getResult();
if (typeSize != 1) {
auto typeSizeVal = rewriter.create<mlir::LLVM::ConstantOp>(
op.getLoc(), llvmDstTy, mlir::IntegerAttr::get(llvmDstTy, typeSize));
op.getLoc(), llvmDstTy, typeSize);

if (dstTy.isUnsigned())
resultVal = rewriter.create<mlir::LLVM::UDivOp>(op.getLoc(), llvmDstTy,
diff, typeSizeVal);
resultVal =
rewriter.create<mlir::LLVM::UDivOp>(op.getLoc(), diff, typeSizeVal);
else
resultVal = rewriter.create<mlir::LLVM::SDivOp>(op.getLoc(), llvmDstTy,
diff, typeSizeVal);
resultVal =
rewriter.create<mlir::LLVM::SDivOp>(op.getLoc(), diff, typeSizeVal);
}
rewriter.replaceOp(op, resultVal);
return mlir::success();
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ mlir::Value SCFLoop::plusConstant(mlir::Value V, mlir::Location loc,
int addend) {
auto type = V.getType();
auto c1 = rewriter->create<mlir::arith::ConstantOp>(
loc, type, mlir::IntegerAttr::get(type, addend));
loc, mlir::IntegerAttr::get(type, addend));
return rewriter->create<mlir::arith::AddIOp>(loc, V, c1);
}

Expand Down Expand Up @@ -285,7 +285,7 @@ void SCFLoop::transferToSCFForOp() {
auto loc = forOp.getLoc();
auto type = lb.getType();
auto step = rewriter->create<mlir::arith::ConstantOp>(
loc, type, mlir::IntegerAttr::get(type, getStep()));
loc, mlir::IntegerAttr::get(type, getStep()));
auto scfForOp = rewriter->create<mlir::scf::ForOp>(loc, lb, ub, step);
SmallVector<mlir::Value> bbArg;
rewriter->eraseOp(&scfForOp.getBody()->back());
Expand Down
8 changes: 4 additions & 4 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,13 +571,13 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern<cir::UnaryOp> {
switch (op.getKind()) {
case cir::UnaryOpKind::Inc: {
auto One = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), type, mlir::IntegerAttr::get(type, 1));
op.getLoc(), mlir::IntegerAttr::get(type, 1));
rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(op, type, input, One);
break;
}
case cir::UnaryOpKind::Dec: {
auto One = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), type, mlir::IntegerAttr::get(type, 1));
op.getLoc(), mlir::IntegerAttr::get(type, 1));
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(op, type, input, One);
break;
}
Expand All @@ -587,13 +587,13 @@ class CIRUnaryOpLowering : public mlir::OpConversionPattern<cir::UnaryOp> {
}
case cir::UnaryOpKind::Minus: {
auto Zero = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), type, mlir::IntegerAttr::get(type, 0));
op.getLoc(), mlir::IntegerAttr::get(type, 0));
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(op, type, Zero, input);
break;
}
case cir::UnaryOpKind::Not: {
auto MinusOne = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), type, mlir::IntegerAttr::get(type, -1));
op.getLoc(), mlir::IntegerAttr::get(type, -1));
rewriter.replaceOpWithNewOp<mlir::arith::XOrIOp>(op, type, MinusOne,
input);
break;
Expand Down
Loading