# Chapter 6: Lowering to LLVM and Code Generation
In the [previous chapter](Ch-5.md), we introduced the
[dialect conversion](../../DialectConversion.md) framework and partially lowered
many of the `Toy` operations to affine loop nests for optimization. In this
chapter, we will finally lower to LLVM for code generation.
## Lowering to LLVM

For this lowering, we will again use the dialect conversion framework to perform
the heavy lifting. However, this time, we will perform a full conversion
to the [LLVM dialect](../../Dialects/LLVM.md). Thankfully, we have already
lowered all but one of the `toy` operations, with the last being `toy.print`.
Before going over the conversion to LLVM, let's lower the `toy.print` operation.
We will lower this operation to a non-affine loop nest that invokes `printf` for
each element. Note that, because the dialect conversion framework supports
[transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering),
we don't need to emit operations in the LLVM dialect directly. By transitive
lowering, we mean that the conversion framework may apply multiple patterns to
fully legalize an operation. In this example, we generate a structured
loop nest instead of the branch form used in the LLVM dialect. As long as we
then have a lowering from the loop operations to LLVM, the lowering will still
succeed.
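Conceptually, the loop nest that `toy.print` lowers to behaves like the
following self-contained sketch, written in plain C++ rather than generated IR
(the function and parameter names here are ours, for exposition only):

```cpp
#include <cstdio>

// Sketch of what the lowered `toy.print` loop nest does for a row-major
// 2D buffer: print each element with printf("%f "), ending each row with
// a newline. This mirrors the structure of the generated loops, not the
// actual IR.
void printBuffer(const double *data, int rows, int cols) {
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j)
      std::printf("%f ", data[i * cols + j]);
    std::putchar('\n');
  }
}
```

The generated IR performs exactly this element-by-element traversal, with the
inner call lowered to an `llvm.call @printf`.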
During lowering we can get, or build, the declaration for printf as follows:
```c++
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                           ModuleOp module) {
  auto *context = module.getContext();
  if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
    return SymbolRefAttr::get(context, "printf");

  // Create a function declaration for printf, the signature is:
  //   * `i32 (i8*, ...)`
  auto llvmI32Ty = IntegerType::get(context, 32);
  auto llvmI8PtrTy =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
                                                /*isVarArg=*/true);

  // Insert the printf function into the body of the parent module.
  PatternRewriter::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
  return SymbolRefAttr::get(context, "printf");
}
```
Now that the lowering for the printf operation has been defined, we can specify
the components necessary for the lowering. These are largely the same as the
components defined in the [previous chapter](Ch-5.md).
### Conversion Target

For this conversion, aside from the top-level module, we will be lowering
everything to the LLVM dialect.
```c++
  mlir::ConversionTarget target(getContext());
  target.addLegalDialect<mlir::LLVM::LLVMDialect>();
  target.addLegalOp<mlir::ModuleOp>();
```
### Type Converter

This lowering will also transform the MemRef types which are currently being
operated on into a representation in LLVM. To perform this conversion, we use a
TypeConverter as part of the lowering. This converter specifies how one type
maps to another. This is necessary now that we are performing more complicated
lowerings involving block arguments. Given that we don't have any
Toy-dialect-specific types that need to be lowered, the default converter is
enough for our use case.
```c++
  LLVMTypeConverter typeConverter(&getContext());
```
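To make the conversion concrete, here is an illustrative C++ analogue of the
struct that the converter produces for the rank-2 memrefs in this chapter's
output, `{ double*, i64, [2 x i64], [2 x i64] }`, together with the index
linearization that the lowered loads perform. The struct, field, and function
names here are ours, for exposition only:

```cpp
#include <cstdint>

// Illustrative analogue of the lowered rank-2 memref descriptor seen in this
// chapter's IR: !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">.
struct MemRefDescriptor2D {
  double *data;       // double*   : pointer to the underlying buffer
  int64_t offset;     // i64       : linear offset into the buffer
  int64_t sizes[2];   // [2 x i64] : dimension sizes
  int64_t strides[2]; // [2 x i64] : per-dimension strides, in elements
};

// A load of element (i, j) linearizes the indices exactly as the generated IR
// does: offset + i * strides[0] + j * strides[1].
double loadElement(const MemRefDescriptor2D &desc, int64_t i, int64_t j) {
  return desc.data[desc.offset + i * desc.strides[0] + j * desc.strides[1]];
}
```

This is exactly the `mul`/`add`/`getelementptr` sequence visible in the lowered
IR later in this chapter.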
### Conversion Patterns
Now that the conversion target has been defined, we need to provide the patterns
used for lowering. At this point in the compilation process, we have a
combination of `toy`, `affine`, `arith`, and `func` operations. Luckily, the
`affine`, `arith`, and `func` dialects already provide the set of patterns needed
to transform them into LLVM dialect. These patterns allow for lowering the IR in
multiple stages by relying on
[transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering).
```c++
  mlir::RewritePatternSet patterns(&getContext());
  mlir::populateAffineToStdConversionPatterns(patterns);
  mlir::populateSCFToControlFlowConversionPatterns(patterns);
  mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
  mlir::populateFuncToLLVMConversionPatterns(typeConverter, patterns);
  mlir::cf::populateControlFlowToLLVMConversionPatterns(patterns, typeConverter);

  // The only remaining operation to lower from the `toy` dialect is
  // toy.print.
  patterns.add<PrintOpLowering>(&getContext());
```
### Full Lowering

We want to completely lower to LLVM, so we use a `FullConversion`. This ensures
that only legal operations will remain after the conversion.
```c++
  mlir::ModuleOp module = getOperation();
  if (mlir::failed(mlir::applyFullConversion(module, target, std::move(patterns))))
    signalPassFailure();
```
Looking back at our current working example:
```mlir
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %3 = toy.mul %2, %2 : tensor<3x2xf64>
  toy.print %3 : tensor<3x2xf64>
  toy.return
}
```
We can now lower down to the LLVM dialect, which produces the following code:
```mlir
llvm.func @free(!llvm<"i8*">)
llvm.func @printf(!llvm<"i8*">, ...) -> i32
llvm.func @malloc(i64) -> !llvm<"i8*">
llvm.func @main() {
  %0 = llvm.mlir.constant(1.000000e+00 : f64) : f64
  %1 = llvm.mlir.constant(2.000000e+00 : f64) : f64

  ...

  %221 = llvm.extractvalue %25[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %222 = llvm.mlir.constant(0 : index) : i64
  %223 = llvm.mlir.constant(2 : index) : i64
  %224 = llvm.mul %214, %223 : i64
  %225 = llvm.add %222, %224 : i64
  %226 = llvm.mlir.constant(1 : index) : i64
  %227 = llvm.mul %219, %226 : i64
  %228 = llvm.add %225, %227 : i64
  %229 = llvm.getelementptr %221[%228] : (!llvm<"double*">, i64) -> !llvm<"double*">
  %230 = llvm.load %229 : !llvm<"double*">
  %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, f64) -> i32
  %232 = llvm.add %219, %218 : i64
  llvm.br ^bb15(%232 : i64)

  ...

  %235 = llvm.extractvalue %65[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%236) : (!llvm<"i8*">) -> ()
  %237 = llvm.extractvalue %45[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%238) : (!llvm<"i8*">) -> ()
  %239 = llvm.extractvalue %25[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
  %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
  llvm.call @free(%240) : (!llvm<"i8*">) -> ()
  llvm.return
}
```
See [LLVM IR Target](../../TargetLLVMIR.md) for
more in-depth details on lowering to the LLVM dialect.
## CodeGen: Getting Out of MLIR
At this point we are right at the cusp of code generation. We can generate code
in the LLVM dialect, so now we just need to export to LLVM IR and set up a JIT
to run it!

### Emitting LLVM IR
Now that our module consists only of operations in the LLVM dialect, we can
export to LLVM IR. To do this programmatically, we can invoke the following
utility:

```c++
  llvm::LLVMContext llvmContext;
  std::unique_ptr<llvm::Module> llvmModule =
      mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule)
    /* ... an error was encountered ... */
```
Exporting our module to LLVM IR generates:
```llvm
define void @main() {
  ...

  %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %104 = mul i64 %96, 2
  %105 = add i64 0, %104
  %106 = mul i64 %100, 1
  %107 = add i64 %105, %106
  %108 = getelementptr double, double* %103, i64 %107
  %109 = load double, double* %108
  %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
  %111 = add i64 %100, 1

  ...

  %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0
  %117 = bitcast double* %116 to i8*
  call void @free(i8* %117)
  %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0
  %119 = bitcast double* %118 to i8*
  call void @free(i8* %119)
  %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
  %121 = bitcast double* %120 to i8*
  call void @free(i8* %121)
  ret void
}
```
If we enable optimization on the generated LLVM IR, we can trim this down quite
a bit:
```llvm
define void @main() {
  %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00)
  %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01)
  %putchar = tail call i32 @putchar(i32 10)
  %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00)
  %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01)
  %putchar.1 = tail call i32 @putchar(i32 10)
  %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00)
  %5 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01)
  %putchar.2 = tail call i32 @putchar(i32 10)
  ret void
}
```
The full code listing for dumping LLVM IR can be found in
`examples/toy/Ch6/toyc.cpp` in the `dumpLLVMIR()` function:
```c++
int dumpLLVMIR(mlir::ModuleOp module) {
  // Translate the module, that contains the LLVM dialect, to LLVM IR. Use a
  // fresh LLVM IR context. (Note that LLVM is not thread-safe and any
  // concurrent use of a context requires external locking.)
  llvm::LLVMContext llvmContext;
  auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
  if (!llvmModule) {
    llvm::errs() << "Failed to emit LLVM IR\n";
    return -1;
  }

  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());

  /// Optionally run an optimization pipeline over the llvm module.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);
  if (auto err = optPipeline(llvmModule.get())) {
    llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
    return -1;
  }
  llvm::errs() << *llvmModule << "\n";
  return 0;
}
```
### Setting up a JIT

Setting up a JIT to run the module containing the LLVM dialect can be done using
the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around
LLVM's JIT that accepts `.mlir` as input. The full code listing for setting up
the JIT can be found in `Ch6/toyc.cpp` in the `runJit()` function:
```c++
int runJit(mlir::ModuleOp module) {
  // Initialize LLVM targets.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // An optimization pipeline to use within the execution engine.
  auto optPipeline = mlir::makeOptimizingTransformer(
      /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
      /*targetMachine=*/nullptr);

  // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
  // the module.
  auto maybeEngine = mlir::ExecutionEngine::create(module,
      /*llvmModuleBuilder=*/nullptr, optPipeline);
  assert(maybeEngine && "failed to construct an execution engine");
  auto &engine = maybeEngine.get();

  // Invoke the JIT-compiled function.
  auto invocationResult = engine->invoke("main");
  if (invocationResult) {
    llvm::errs() << "JIT invocation failed\n";
    return -1;
  }

  return 0;
}
```
You can play around with it from the build directory:

```shell
$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit
1.000000 2.000000
3.000000 4.000000
```
You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and
`-emit=llvm` to compare the various levels of IR involved. Also try options like
[`--mlir-print-ir-after-all`](../../PassManagement.md/#ir-printing) to track the
evolution of the IR throughout the pipeline.
The example code used throughout this section can be found in
`test/Examples/Toy/Ch6/llvm-lowering.mlir`.
So far, we have worked with primitive data types. In the
[next chapter](Ch-7.md), we will add a composite `struct` type.