[TOC]
In the previous chapter, we introduced the dialect conversion framework and partially lowered many of the `Toy` operations to affine loop nests for optimization. In this chapter, we will finally lower to LLVM for code generation.
For this lowering, we will again use the dialect conversion framework to perform the heavy lifting. However, this time, we will be performing a full conversion to the LLVM dialect. Thankfully, we have already lowered all but one of the `toy` operations, with the last being `toy.print`.
Before going over the conversion to LLVM, let's lower the `toy.print` operation.
We will lower this operation to a non-affine loop nest that invokes `printf` for each element. Note that, because the dialect conversion framework supports transitive lowering, we don't need to directly emit operations in the LLVM dialect. By transitive lowering, we mean that the conversion framework may apply multiple patterns to fully legalize an operation. In this example, we are generating a structured loop nest instead of the branch form in the LLVM dialect. As long as we then have a lowering from the loop operations to LLVM, the lowering will still succeed.
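To make transitive lowering concrete, here is a small, self-contained C++ model of it. It does not use the MLIR API: operations are plain strings, and each hypothetical "pattern" rewrites one op kind into a list of other op kinds. An op is accepted once it is in the legal set; otherwise its pattern is applied and every produced op is legalized in turn. The op names and the `legalize` helper are illustrative assumptions, and the patterns are heavily simplified compared to the real lowerings.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical, string-based model: each pattern maps one op kind to the op
// kinds it produces. These names do not come from the MLIR API.
using Patterns = std::map<std::string, std::vector<std::string>>;

// Legalize `op` transitively. Legal ops are emitted as-is; otherwise the
// matching pattern is applied and its results are legalized recursively.
// Returns false if some op is neither legal nor rewritable.
// Assumes the pattern graph is acyclic (no rewrite loops).
bool legalize(const std::string &op, const Patterns &patterns,
              const std::set<std::string> &legal,
              std::vector<std::string> &out) {
  if (legal.count(op)) {
    out.push_back(op);
    return true;
  }
  auto it = patterns.find(op);
  if (it == patterns.end())
    return false;  // no pattern applies: this op cannot be legalized
  for (const std::string &produced : it->second)
    if (!legalize(produced, patterns, legal, out))
      return false;
  return true;
}
```

With patterns `toy.print -> {loop.for, llvm.call}`, `loop.for -> {std.br, std.cond_br}`, and each `std` branch mapped to its `llvm` counterpart, legalizing `toy.print` yields `llvm.br`, `llvm.cond_br`, `llvm.call`, even though no single pattern emits LLVM-dialect branches directly. A real framework must also guard against rewrite cycles, which this model ignores.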
During lowering, we can get or build the declaration for `printf` like so:
```cpp
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                           ModuleOp module,
                                           LLVM::LLVMDialect *llvmDialect) {
  auto *context = module.getContext();
  if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
    return SymbolRefAttr::get("printf", context);

  // Create a function declaration for printf, the signature is:
  //   * `i32 (i8*, ...)`
  auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
  auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
  auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
                                                  /*isVarArg=*/true);

  // Insert the printf function into the body of the parent module.
  PatternRewriter::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
  return SymbolRefAttr::get("printf", context);
}
```
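The helper above follows a "get or insert" idiom: look the symbol up first, and create a declaration only when it is missing, so that every `toy.print` lowered in a module ends up referencing one shared `printf` declaration. The sketch below models just that idiom with a hypothetical `Module` holding a list of function declarations; it is not MLIR's `ModuleOp`, and the names are our own.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a module's top-level symbols (not the MLIR API).
struct FuncDecl {
  std::string name;
  std::string signature;  // e.g. "i32 (i8*, ...)"
};

struct Module {
  std::vector<FuncDecl> body;

  // Return the declaration with the given name, or nullptr if absent.
  const FuncDecl *lookupSymbol(const std::string &name) const {
    for (const FuncDecl &f : body)
      if (f.name == name)
        return &f;
    return nullptr;
  }
};

// Return the symbol name to reference, declaring printf at the start of the
// module body only if no declaration exists yet.
std::string getOrInsertPrintf(Module &module) {
  if (!module.lookupSymbol("printf"))
    module.body.insert(module.body.begin(),
                       FuncDecl{"printf", "i32 (i8*, ...)"});
  return "printf";
}
```

Calling `getOrInsertPrintf` twice on the same module adds the declaration only once.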
Now that the lowering for the `toy.print` operation has been defined, we can specify the components necessary for the lowering. These are largely the same as the components defined in the previous chapter.
For this conversion, aside from the top-level module, we will be lowering everything to the LLVM dialect.
```cpp
mlir::ConversionTarget target(getContext());
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
```
This lowering will also transform the MemRef types which are currently being operated on into a representation in LLVM. To perform this conversion, we use a TypeConverter as part of the lowering. This converter specifies how one type maps to another. This is necessary now that we are performing more complicated lowerings involving block arguments. Given that we don't have any Toy-dialect-specific types that need to be lowered, the default converter is enough for our use case.
```cpp
LLVMTypeConverter typeConverter(&getContext());
```
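For intuition about what the type converter produces for a memref, the following C++ struct mirrors the 2-D descriptor type that appears in the LLVM dialect output later in this chapter, `!llvm<"{ double*, i64, [2 x i64], [2 x i64] }">`: a data pointer, an offset, the sizes, and the strides. The field names and the `makeRowMajor` helper are our own assumptions. Later MLIR versions use a descriptor with separate allocated and aligned pointers, so treat this layout as specific to the version shown here.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Rough C++ mirror of !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">.
// Field names are illustrative, not taken from MLIR.
struct MemRefDescriptor2D {
  double *ptr;                     // data pointer
  int64_t offset;                  // element offset into `ptr`
  std::array<int64_t, 2> sizes;    // extent of each dimension
  std::array<int64_t, 2> strides;  // distance (in elements) per dimension step
};

// Build a descriptor for a contiguous, row-major buffer with zero offset,
// the identity layout used by the statically shaped memrefs in this example.
MemRefDescriptor2D makeRowMajor(double *data, int64_t rows, int64_t cols) {
  assert(rows >= 0 && cols >= 0);
  return MemRefDescriptor2D{data, /*offset=*/0, {rows, cols}, {cols, 1}};
}
```

For a `3x2` buffer this gives strides `{2, 1}`: moving one row skips two elements, moving one column skips one.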
Now that the conversion target has been defined, we need to provide the patterns used for lowering. At this point in the compilation process, we have a combination of `toy`, `affine`, and `std` operations. Luckily, the `std` and `affine` dialects already provide the set of patterns needed to transform them into the LLVM dialect, and the `loop` operations produced along the way (for example, by our `toy.print` lowering) are handled by the loop-to-std patterns. These patterns allow for lowering the IR in multiple stages by relying on transitive lowering.
```cpp
mlir::OwningRewritePatternList patterns;
mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);

// The only remaining operation to lower from the `toy` dialect is the
// PrintOp.
patterns.insert<PrintOpLowering>(&getContext());
```
We want to completely lower to LLVM, so we use a `FullConversion`. This ensures that only legal operations will remain after the conversion.
```cpp
mlir::ModuleOp module = getModule();
if (mlir::failed(mlir::applyFullConversion(module, target, patterns,
                                           &typeConverter)))
  signalPassFailure();
```
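As a rough model of what "full" means here, the sketch below checks the operations left over after pattern application. Under a full conversion, any operation not explicitly marked legal causes failure; under a partial conversion (as used in the previous chapter), only operations explicitly marked illegal do. The op names and the `conversionSucceeded` helper are hypothetical, and the real framework also supports dynamically legal operations, which this model omits.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

enum class Mode { Partial, Full };

// Hypothetical model of the post-conversion legality check.
bool conversionSucceeded(const std::vector<std::string> &remainingOps,
                         const std::set<std::string> &legalOps,
                         const std::set<std::string> &illegalOps, Mode mode) {
  for (const std::string &op : remainingOps) {
    // Full conversion: anything not explicitly legal is an error.
    if (mode == Mode::Full && !legalOps.count(op))
      return false;
    // Partial conversion: only explicitly illegal ops are an error; unknown
    // ops may remain for a later pass to handle.
    if (mode == Mode::Partial && illegalOps.count(op))
      return false;
  }
  return true;
}
```

A leftover op that is neither legal nor illegal (say, an unconverted `std.addf`) is tolerated by the partial mode but rejected by the full mode.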
Looking back at our current working example:
```mlir
func @main() {
  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
  %2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
  %3 = "toy.mul"(%2, %2) : (tensor<3x2xf64>, tensor<3x2xf64>) -> tensor<3x2xf64>
  "toy.print"(%3) : (tensor<3x2xf64>) -> ()
  "toy.return"() : () -> ()
}
```
We can now lower down to the LLVM dialect, which produces the following code:
```mlir
llvm.func @free(!llvm<"i8*">)
llvm.func @printf(!llvm<"i8*">, ...) -> !llvm.i32
llvm.func @malloc(!llvm.i64) -> !llvm<"i8*">
llvm.func @main() {
  %0 = llvm.mlir.constant(1.000000e+00 : f64) : !llvm.double
  %1 = llvm.mlir.constant(2.000000e+00 : f64) : !llvm.double

  ...

^bb16:
  %221 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %222 = llvm.mlir.constant(0 : index) : !llvm.i64
  %223 = llvm.mlir.constant(2 : index) : !llvm.i64
  %224 = llvm.mul %214, %223 : !llvm.i64
  %225 = llvm.add %222, %224 : !llvm.i64
  %226 = llvm.mlir.constant(1 : index) : !llvm.i64
  %227 = llvm.mul %219, %226 : !llvm.i64
  %228 = llvm.add %225, %227 : !llvm.i64
  %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, !llvm.i64) -> !llvm<"double*">
  %230 = llvm.load %229 : !llvm<"double*">
  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32
  %232 = llvm.add %219, %218 : !llvm.i64
  llvm.br ^bb15(%232 : !llvm.i64)

  ...

^bb18:
  %235 = llvm.extractvalue %65[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
  %237 = llvm.extractvalue %45[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
  %239 = llvm.extractvalue %25[0 : index] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
  llvm.return
}
```
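The `^bb16` block above is the element load inside the print loop nest: it extracts the data pointer from the descriptor, computes `offset + i * stride0 + j * stride1` (with `%214` as the row index and `%219` as the column index), and then uses `getelementptr` and `load`. Because the shape is static, the offset `0` and strides `2` and `1` appear as constants. The C++ sketch below restates that address arithmetic over a descriptor struct whose field names are our own assumption.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Illustrative mirror of !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">.
struct MemRefDescriptor2D {
  double *ptr;
  int64_t offset;
  std::array<int64_t, 2> sizes;
  std::array<int64_t, 2> strides;
};

// Same computation as %222..%230: linearize (i, j) with the offset and
// strides, then load. The generated code uses constants for these values
// because the memref shape is known statically.
double loadElement(const MemRefDescriptor2D &d, int64_t i, int64_t j) {
  assert(i >= 0 && i < d.sizes[0] && j >= 0 && j < d.sizes[1]);
  int64_t linear = d.offset + i * d.strides[0] + j * d.strides[1];
  return d.ptr[linear];  // getelementptr + load
}
```

For a `3x2` row-major buffer, element `(2, 0)` lives at linear index `0 + 2 * 2 + 0 * 1 = 4`.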
See Conversion to the LLVM IR Dialect for more in-depth details on lowering to the LLVM dialect.
At this point, we are right at the cusp of code generation. We can generate code in the LLVM dialect, so now we just need to export to LLVM IR and set up a JIT to run it.
Now that our module consists only of operations in the LLVM dialect, we can export to LLVM IR. To do this programmatically, we can invoke the following utility:
```cpp
std::unique_ptr<llvm::Module> llvmModule = mlir::translateModuleToLLVMIR(module);
if (!llvmModule) {
  /* ... an error was encountered ... */
}
```
Exporting our module to LLVM IR generates:
```llvm
define void @main() {
  ...

102:
  %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %104 = mul i64 %96, 2
  %105 = add i64 0, %104
  %106 = mul i64 %100, 1
  %107 = add i64 %105, %106
  %108 = getelementptr double, double* %103, i64 %107
  %109 = load double, double* %108
  %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
  %111 = add i64 %100, 1
  br label %99

  ...

115:
  %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0
  %117 = bitcast double* %116 to i8*
  call void @free(i8* %117)
  %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0
  %119 = bitcast double* %118 to i8*
  call void @free(i8* %119)
  %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %121 = bitcast double* %120 to i8*
  call void @free(i8* %121)
  ret void
}
```
If we enable optimization on the generated LLVM IR, we can trim this down quite a bit:
```llvm
define void @main() {
  %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00)
  %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01)
  %putchar = tail call i32 @putchar(i32 10)
  %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00)
  %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01)
  %putchar.1 = tail call i32 @putchar(i32 10)
  %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00)
  %5 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01)
  %putchar.2 = tail call i32 @putchar(i32 10)
  ret void
}
```
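The optimized IR has folded the whole computation into constants: the `printf` arguments 1, 16, 4, 25, 9, 36 are the transpose of the 2x3 input multiplied element-wise by itself, printed row by row with a `putchar(10)` (a newline) after each row. The plain C++ restatement below, which involves no MLIR, reproduces those values as a sanity check; the type aliases and function names are our own.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using Mat2x3 = std::array<std::array<double, 3>, 2>;
using Mat3x2 = std::array<std::array<double, 2>, 3>;

// toy.transpose: 2x3 -> 3x2.
Mat3x2 transpose(const Mat2x3 &a) {
  Mat3x2 t{};
  for (std::size_t i = 0; i < 2; ++i)
    for (std::size_t j = 0; j < 3; ++j)
      t[j][i] = a[i][j];
  return t;
}

// toy.mul: element-wise product of two 3x2 matrices.
Mat3x2 mulElementwise(const Mat3x2 &x, const Mat3x2 &y) {
  Mat3x2 r{};
  for (std::size_t i = 0; i < 3; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      r[i][j] = x[i][j] * y[i][j];
  return r;
}
```

Applied to `[[1, 2, 3], [4, 5, 6]]`, the rows of the result are `[1, 16]`, `[4, 25]`, and `[9, 36]`, matching the order of the `printf` calls above.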
The full code listing for dumping LLVM IR can be found in `Ch6/toy.cpp` in the `dumpLLVMIR()` function:
```cpp
int dumpLLVMIR(mlir::ModuleOp module) {
  // Translate the module, that contains the LLVM dialect, to LLVM IR.
  auto llvmModule = mlir::translateModuleToLLVMIR(module);
  if (!llvmModule) {
    llvm::errs() << "Failed to emit LLVM IR\n";
    return -1;
  }

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());

  /// Optionally run an optimization pipeline over the llvm module.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);
  if (auto err = optPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
    return -1;
  }
  llvm::errs() << *llvmModule << "\n";
  return 0;
}
```
Setting up a JIT to run the module containing the LLVM dialect can be done using the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around LLVM's JIT that accepts `.mlir` as input. The full code listing for setting up the JIT can be found in `Ch6/toy.cpp` in the `runJit()` function:
```cpp
int runJit(mlir::ModuleOp module) {
  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // An optimization pipeline to use within the execution engine.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);

  // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
  // the module.
  auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline);
  assert(maybeEngine && "failed to construct an execution engine");
  auto &engine = maybeEngine.get();

  // Invoke the JIT-compiled function.
  auto invocationResult = engine->invoke("main");
  if (invocationResult) {
    llvm::errs() << "JIT invocation failed\n";
    return -1;
  }

  return 0;
}
```
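One detail in `runJit()` can look inverted: `if (invocationResult)` takes the failure branch. `ExecutionEngine::invoke` returns an `llvm::Error`, and converting an `llvm::Error` to `bool` yields `true` when it holds a failure. The hypothetical `MiniError` class and `invoke` stand-in below imitate only that convention to illustrate the control flow. The real `llvm::Error` additionally requires every error to be checked and, when LLVM's ABI-breaking checks are enabled, aborts if a failure is destroyed without being handled (for example via `llvm::toString` or `llvm::consumeError`).

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical imitation of the llvm::Error truthiness convention:
// `true` means "holds a failure", `false` means success.
class MiniError {
public:
  static MiniError success() { return MiniError(""); }
  static MiniError failure(std::string msg) {
    assert(!msg.empty() && "failures must carry a message in this model");
    return MiniError(std::move(msg));
  }
  explicit operator bool() const { return !message_.empty(); }
  const std::string &message() const { return message_; }

private:
  explicit MiniError(std::string msg) : message_(std::move(msg)) {}
  std::string message_;  // empty on success
};

// Hypothetical stand-in for engine->invoke(name).
MiniError invoke(const std::string &name) {
  if (name == "main")
    return MiniError::success();
  return MiniError::failure("symbol not found: " + name);
}
```

So `if (auto err = invoke("main"))` reads as "if invoking `main` produced an error".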
You can play around with it from the build directory:
```shell
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit
1.000000 2.000000
3.000000 4.000000
```
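The output above comes from the loop nest that `toy.print` is lowered to: one `printf` call per element and a newline after each row. The sketch below restates that loop nest in plain C++, writing into a string with `snprintf` so the result can be inspected. The `"%f "` element format is an assumption, consistent with the 4-byte `@frmt_spec` global in the IR above and with the JIT output; `formatMatrix` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical restatement of the toy.print loop nest: one formatted element
// per inner iteration, one newline per outer iteration.
std::string formatMatrix(const std::vector<std::vector<double>> &m) {
  std::string out;
  char buf[64];
  for (const std::vector<double> &row : m) {
    for (double v : row) {
      std::snprintf(buf, sizeof(buf), "%f ", v);  // assumed element format
      out += buf;
    }
    out += '\n';  // corresponds to putchar(10) in the optimized IR
  }
  return out;
}
```

For `[[1, 2], [3, 4]]` this returns `"1.000000 2.000000 \n3.000000 4.000000 \n"`; the trailing space after each element is invisible in a terminal.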
You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and `-emit=llvm` to compare the various levels of IR involved. Also try options like `--print-ir-after-all` to track the evolution of the IR throughout the pipeline.
So far, we have worked with primitive data types. In the next chapter, we will add a composite `struct` type.