//===-- LLJITWithOptimizingIRTransform.cpp -- LLJIT with IR optimization --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// In this example we will use an IR transform to optimize a module as it
// passes through LLJIT's IRTransformLayer.
//
//===----------------------------------------------------------------------===//
14 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
15 #include "llvm/IR/LegacyPassManager.h"
16 #include "llvm/Support/InitLLVM.h"
17 #include "llvm/Support/TargetSelect.h"
18 #include "llvm/Support/raw_ostream.h"
19 #include "llvm/Transforms/IPO.h"
20 #include "llvm/Transforms/Scalar.h"
22 #include "../ExampleModules.h"
25 using namespace llvm::orc
;
27 ExitOnError ExitOnErr
;
31 // This IR contains a recursive definition of the factorial function:
33 // fac(n) | n == 0 = 1
34 // | otherwise = n * fac(n - 1)
36 // It also contains an entry function which calls the factorial function with
37 // an input value of 5.
39 // We expect the IR optimization transform that we build below to transform
40 // this into a non-recursive factorial function and an entry function that
41 // returns a constant value of 5!, or 120.
43 const llvm::StringRef MainMod
=
46 define i32 @fac(i32 %n) {
48 %tobool = icmp eq i32 %n, 0
49 br i1 %tobool, label %return, label %if.then
51 if.then: ; preds = %entry
52 %arg = add nsw i32 %n, -1
53 %call_result = call i32 @fac(i32 %arg)
54 %result = mul nsw i32 %n, %call_result
57 return: ; preds = %entry, %if.then
58 %final_result = phi i32 [ %result, %if.then ], [ 1, %entry ]
64 %result = call i32 @fac(i32 5)
70 // A function object that creates a simple pass pipeline to apply to each
71 // module as it passes through the IRTransformLayer.
72 class MyOptimizationTransform
{
74 MyOptimizationTransform() : PM(std::make_unique
<legacy::PassManager
>()) {
75 PM
->add(createTailCallEliminationPass());
76 PM
->add(createFunctionInliningPass());
77 PM
->add(createIndVarSimplifyPass());
78 PM
->add(createCFGSimplificationPass());
81 Expected
<ThreadSafeModule
> operator()(ThreadSafeModule TSM
,
82 MaterializationResponsibility
&R
) {
83 TSM
.withModuleDo([this](Module
&M
) {
84 dbgs() << "--- BEFORE OPTIMIZATION ---\n" << M
<< "\n";
86 dbgs() << "--- AFTER OPTIMIZATION ---\n" << M
<< "\n";
88 return std::move(TSM
);
92 std::unique_ptr
<legacy::PassManager
> PM
;
95 int main(int argc
, char *argv
[]) {
97 InitLLVM
X(argc
, argv
);
99 InitializeNativeTarget();
100 InitializeNativeTargetAsmPrinter();
102 ExitOnErr
.setBanner(std::string(argv
[0]) + ": ");
104 // (1) Create LLJIT instance.
105 auto J
= ExitOnErr(LLJITBuilder().create());
107 // (2) Install transform to optimize modules when they're materialized.
108 J
->getIRTransformLayer().setTransform(MyOptimizationTransform());
111 ExitOnErr(J
->addIRModule(ExitOnErr(parseExampleModule(MainMod
, "MainMod"))));
113 // (4) Look up the JIT'd function and call it.
114 auto EntrySym
= ExitOnErr(J
->lookup("entry"));
115 auto *Entry
= (int (*)())EntrySym
.getAddress();
117 int Result
= Entry();
118 outs() << "--- Result ---\n"
119 << "entry() = " << Result
<< "\n";