[analyzer][NFC] Remove "V2" from ArrayBoundCheckerV2.cpp (#126094)
[llvm-project.git] / mlir / docs / Tutorials / Toy / Ch-6.md
blobe8a68b5f9ee38875d2eabfb22a3bee7d6785bc72
1 # Chapter 6: Lowering to LLVM and CodeGeneration
3 [TOC]
5 In the [previous chapter](Ch-5.md), we introduced the
6 [dialect conversion](../../DialectConversion.md) framework and partially lowered
7 many of the `Toy` operations to affine loop nests for optimization. In this
8 chapter, we will finally lower to LLVM for code generation.
10 ## Lowering to LLVM
12 For this lowering, we will again use the dialect conversion framework to perform
13 the heavy lifting. However, this time, we will be performing a full conversion
14 to the [LLVM dialect](../../Dialects/LLVM.md). Thankfully, we have already
15 lowered all but one of the `toy` operations, with the last being `toy.print`.
16 Before going over the conversion to LLVM, let's lower the `toy.print` operation.
17 We will lower this operation to a non-affine loop nest that invokes `printf` for
18 each element. Note that, because the dialect conversion framework supports
19 [transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering),
20 we don't need to directly emit operations in the LLVM dialect. By transitive
21 lowering, we mean that the conversion framework may apply multiple patterns to
22 fully legalize an operation. In this example, we are generating a structured
23 loop nest instead of the branch-form in the LLVM dialect. As long as we then
24 have a lowering from the loop operations to LLVM, the lowering will still
25 succeed.
27 During lowering we can get, or build, the declaration for printf as so:
29 ```c++
30 /// Return a symbol reference to the printf function, inserting it into the
31 /// module if necessary.
32 static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
33                                            ModuleOp module,
34                                            LLVM::LLVMDialect *llvmDialect) {
35   auto *context = module.getContext();
36   if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
37     return SymbolRefAttr::get("printf", context);
39   // Create a function declaration for printf, the signature is:
40   //   * `i32 (i8*, ...)`
41   auto llvmI32Ty = IntegerType::get(context, 32);
42   auto llvmI8PtrTy =
43       LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
44   auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
45                                                 /*isVarArg=*/true);
47   // Insert the printf function into the body of the parent module.
48   PatternRewriter::InsertionGuard insertGuard(rewriter);
49   rewriter.setInsertionPointToStart(module.getBody());
50   rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
51   return SymbolRefAttr::get("printf", context);
53 ```
55 Now that the lowering for the printf operation has been defined, we can specify
56 the components necessary for the lowering. These are largely the same as the
57 components defined in the [previous chapter](Ch-5.md).
59 ### Conversion Target
61 For this conversion, aside from the top-level module, we will be lowering
62 everything to the LLVM dialect.
64 ```c++
65   mlir::ConversionTarget target(getContext());
66   target.addLegalDialect<mlir::LLVMDialect>();
67   target.addLegalOp<mlir::ModuleOp>();
68 ```
70 ### Type Converter
72 This lowering will also transform the MemRef types which are currently being
73 operated on into a representation in LLVM. To perform this conversion, we use a
74 TypeConverter as part of the lowering. This converter specifies how one type
75 maps to another. This is necessary now that we are performing more complicated
76 lowerings involving block arguments. Given that we don't have any
77 Toy-dialect-specific types that need to be lowered, the default converter is
78 enough for our use case.
80 ```c++
81   LLVMTypeConverter typeConverter(&getContext());
82 ```
84 ### Conversion Patterns
86 Now that the conversion target has been defined, we need to provide the patterns
87 used for lowering. At this point in the compilation process, we have a
88 combination of `toy`, `affine`, `arith`, and `std` operations. Luckily, the
89 `affine`, `arith`, and `std` dialects already provide the set of patterns needed
90 to transform them into LLVM dialect. These patterns allow for lowering the IR in
91 multiple stages by relying on
92 [transitive lowering](../../../getting_started/Glossary.md/#transitive-lowering).
94 ```c++
95   mlir::RewritePatternSet patterns(&getContext());
96   mlir::populateAffineToStdConversionPatterns(patterns, &getContext());
97   mlir::cf::populateSCFToControlFlowConversionPatterns(patterns, &getContext());
98   mlir::arith::populateArithToLLVMConversionPatterns(typeConverter,
99                                                           patterns);
100   mlir::populateFuncToLLVMConversionPatterns(typeConverter, patterns);
101   mlir::cf::populateControlFlowToLLVMConversionPatterns(patterns, &getContext());
103   // The only remaining operation, to lower from the `toy` dialect, is the
104   // PrintOp.
105   patterns.add<PrintOpLowering>(&getContext());
108 ### Full Lowering
110 We want to completely lower to LLVM, so we use a `FullConversion`. This ensures
111 that only legal operations will remain after the conversion.
113 ```c++
114   mlir::ModuleOp module = getOperation();
115   if (mlir::failed(mlir::applyFullConversion(module, target, patterns)))
116     signalPassFailure();
119 Looking back at our current working example:
121 ```mlir
122 toy.func @main() {
123   %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
124   %2 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
125   %3 = toy.mul %2, %2 : tensor<3x2xf64>
126   toy.print %3 : tensor<3x2xf64>
127   toy.return
131 We can now lower down to the LLVM dialect, which produces the following code:
133 ```mlir
134 llvm.func @free(!llvm<"i8*">)
135 llvm.func @printf(!llvm<"i8*">, ...) -> i32
136 llvm.func @malloc(i64) -> !llvm<"i8*">
137 llvm.func @main() {
138   %0 = llvm.mlir.constant(1.000000e+00 : f64) : f64
139   %1 = llvm.mlir.constant(2.000000e+00 : f64) : f64
141   ...
143 ^bb16:
144   %221 = llvm.extractvalue %25[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
145   %222 = llvm.mlir.constant(0 : index) : i64
146   %223 = llvm.mlir.constant(2 : index) : i64
147   %224 = llvm.mul %214, %223 : i64
148   %225 = llvm.add %222, %224 : i64
149   %226 = llvm.mlir.constant(1 : index) : i64
150   %227 = llvm.mul %219, %226 : i64
151   %228 = llvm.add %225, %227 : i64
152   %229 = llvm.getelementptr %221[%228] : (!llvm."double*">, i64) -> !llvm<"f64*">
153   %230 = llvm.load %229 : !llvm<"double*">
154   %231 = llvm.call @printf(%207, %230) : (!llvm<"i8*">, f64) -> i32
155   %232 = llvm.add %219, %218 : i64
156   llvm.br ^bb15(%232 : i64)
158   ...
160 ^bb18:
161   %235 = llvm.extractvalue %65[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
162   %236 = llvm.bitcast %235 : !llvm<"double*"> to !llvm<"i8*">
163   llvm.call @free(%236) : (!llvm<"i8*">) -> ()
164   %237 = llvm.extractvalue %45[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
165   %238 = llvm.bitcast %237 : !llvm<"double*"> to !llvm<"i8*">
166   llvm.call @free(%238) : (!llvm<"i8*">) -> ()
167   %239 = llvm.extractvalue %25[0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
168   %240 = llvm.bitcast %239 : !llvm<"double*"> to !llvm<"i8*">
169   llvm.call @free(%240) : (!llvm<"i8*">) -> ()
170   llvm.return
174 See [LLVM IR Target](../../TargetLLVMIR.md) for
175 more in-depth details on lowering to the LLVM dialect.
177 ## CodeGen: Getting Out of MLIR
179 At this point we are right at the cusp of code generation. We can generate code
180 in the LLVM dialect, so now we just need to export to LLVM IR and setup a JIT to
181 run it.
183 ### Emitting LLVM IR
185 Now that our module is comprised only of operations in the LLVM dialect, we can
186 export to LLVM IR. To do this programmatically, we can invoke the following
187 utility:
189 ```c++
190   std::unique_ptr<llvm::Module> llvmModule = mlir::translateModuleToLLVMIR(module);
191   if (!llvmModule)
192     /* ... an error was encountered ... */
195 Exporting our module to LLVM IR generates:
197 ```llvm
198 define void @main() {
199   ...
201 102:
202   %103 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
203   %104 = mul i64 %96, 2
204   %105 = add i64 0, %104
205   %106 = mul i64 %100, 1
206   %107 = add i64 %105, %106
207   %108 = getelementptr double, double* %103, i64 %107
208   %109 = memref.load double, double* %108
209   %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109)
210   %111 = add i64 %100, 1
211   cf.br label %99
213   ...
215 115:
216   %116 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %24, 0
217   %117 = bitcast double* %116 to i8*
218   call void @free(i8* %117)
219   %118 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %16, 0
220   %119 = bitcast double* %118 to i8*
221   call void @free(i8* %119)
222   %120 = extractvalue { double*, i64, [2 x i64], [2 x i64] } %8, 0
223   %121 = bitcast double* %120 to i8*
224   call void @free(i8* %121)
225   ret void
229 If we enable optimization on the generated LLVM IR, we can trim this down quite
230 a bit:
232 ```llvm
233 define void @main()
234   %0 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.000000e+00)
235   %1 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 1.600000e+01)
236   %putchar = tail call i32 @putchar(i32 10)
237   %2 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 4.000000e+00)
238   %3 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 2.500000e+01)
239   %putchar.1 = tail call i32 @putchar(i32 10)
240   %4 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 9.000000e+00)
241   %5 = tail call i32 (i8*, ...) @printf(i8* nonnull dereferenceable(1) getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double 3.600000e+01)
242   %putchar.2 = tail call i32 @putchar(i32 10)
243   ret void
247 The full code listing for dumping LLVM IR can be found in
248 `examples/toy/Ch6/toy.cpp` in the `dumpLLVMIR()` function:
250 ```c++
252 int dumpLLVMIR(mlir::ModuleOp module) {
253   // Translate the module, that contains the LLVM dialect, to LLVM IR. Use a
254   // fresh LLVM IR context. (Note that LLVM is not thread-safe and any
255   // concurrent use of a context requires external locking.)
256   llvm::LLVMContext llvmContext;
257   auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext);
258   if (!llvmModule) {
259     llvm::errs() << "Failed to emit LLVM IR\n";
260     return -1;
261   }
263   // Initialize LLVM targets.
264   llvm::InitializeNativeTarget();
265   llvm::InitializeNativeTargetAsmPrinter();
266   mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
268   /// Optionally run an optimization pipeline over the llvm module.
269   auto optPipeline = mlir::makeOptimizingTransformer(
270       /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
271       /*targetMachine=*/nullptr);
272   if (auto err = optPipeline(llvmModule.get())) {
273     llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
274     return -1;
275   }
276   llvm::errs() << *llvmModule << "\n";
277   return 0;
281 ### Setting up a JIT
283 Setting up a JIT to run the module containing the LLVM dialect can be done using
284 the `mlir::ExecutionEngine` infrastructure. This is a utility wrapper around
285 LLVM's JIT that accepts `.mlir` as input. The full code listing for setting up
286 the JIT can be found in `Ch6/toyc.cpp` in the `runJit()` function:
288 ```c++
289 int runJit(mlir::ModuleOp module) {
290   // Initialize LLVM targets.
291   llvm::InitializeNativeTarget();
292   llvm::InitializeNativeTargetAsmPrinter();
294   // An optimization pipeline to use within the execution engine.
295   auto optPipeline = mlir::makeOptimizingTransformer(
296       /*optLevel=*/EnableOpt ? 3 : 0, /*sizeLevel=*/0,
297       /*targetMachine=*/nullptr);
299   // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
300   // the module.
301   auto maybeEngine = mlir::ExecutionEngine::create(module,
302       /*llvmModuleBuilder=*/nullptr, optPipeline);
303   assert(maybeEngine && "failed to construct an execution engine");
304   auto &engine = maybeEngine.get();
306   // Invoke the JIT-compiled function.
307   auto invocationResult = engine->invoke("main");
308   if (invocationResult) {
309     llvm::errs() << "JIT invocation failed\n";
310     return -1;
311   }
313   return 0;
317 You can play around with it from the build directory:
319 ```shell
320 $ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch6 -emit=jit
321 1.000000 2.000000
322 3.000000 4.000000
325 You can also play with `-emit=mlir`, `-emit=mlir-affine`, `-emit=mlir-llvm`, and
326 `-emit=llvm` to compare the various levels of IR involved. Also try options like
327 [`--mlir-print-ir-after-all`](../../PassManagement.md/#ir-printing) to track the
328 evolution of the IR throughout the pipeline.
330 The example code used throughout this section can be found in
331 test/Examples/Toy/Ch6/llvm-lowering.mlir.
333 So far, we have worked with primitive data types. In the
334 [next chapter](Ch-7.md), we will add a composite `struct` type.