[ThinLTO] Add code comment. NFC
[llvm-complete.git] / examples / Kaleidoscope / BuildingAJIT / Chapter5 / KaleidoscopeJIT.h
blob1d9c98a9d72ad0ce2f28300710ff4e09d0b0ab0e
1 //===- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains a simple JIT definition for use in the kaleidoscope tutorials.
11 //===----------------------------------------------------------------------===//
13 #ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
14 #define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
16 #include "RemoteJITUtils.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/Triple.h"
20 #include "llvm/ExecutionEngine/ExecutionEngine.h"
21 #include "llvm/ExecutionEngine/JITSymbol.h"
22 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
23 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
24 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
25 #include "llvm/ExecutionEngine/Orc/IndirectionUtils.h"
26 #include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
27 #include "llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h"
28 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
29 #include "llvm/IR/DataLayout.h"
30 #include "llvm/IR/LegacyPassManager.h"
31 #include "llvm/IR/Mangler.h"
32 #include "llvm/Support/DynamicLibrary.h"
33 #include "llvm/Support/Error.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include "llvm/Target/TargetMachine.h"
36 #include "llvm/Transforms/InstCombine/InstCombine.h"
37 #include "llvm/Transforms/Scalar.h"
38 #include "llvm/Transforms/Scalar/GVN.h"
39 #include <algorithm>
40 #include <cassert>
41 #include <cstdlib>
42 #include <map>
43 #include <memory>
44 #include <string>
45 #include <vector>
47 class PrototypeAST;
48 class ExprAST;
50 /// FunctionAST - This class represents a function definition itself.
51 class FunctionAST {
52 std::unique_ptr<PrototypeAST> Proto;
53 std::unique_ptr<ExprAST> Body;
55 public:
56 FunctionAST(std::unique_ptr<PrototypeAST> Proto,
57 std::unique_ptr<ExprAST> Body)
58 : Proto(std::move(Proto)), Body(std::move(Body)) {}
60 const PrototypeAST& getProto() const;
61 const std::string& getName() const;
62 llvm::Function *codegen();
65 /// This will compile FnAST to IR, rename the function to add the given
66 /// suffix (needed to prevent a name-clash with the function's stub),
67 /// and then take ownership of the module that the function was compiled
68 /// into.
69 std::unique_ptr<llvm::Module>
70 irgenAndTakeOwnership(FunctionAST &FnAST, const std::string &Suffix);
72 namespace llvm {
73 namespace orc {
75 // Typedef the remote-client API.
76 using MyRemote = remote::OrcRemoteTargetClient;
78 class KaleidoscopeJIT {
79 private:
80 ExecutionSession &ES;
81 std::shared_ptr<SymbolResolver> Resolver;
82 std::unique_ptr<TargetMachine> TM;
83 const DataLayout DL;
84 LegacyRTDyldObjectLinkingLayer ObjectLayer;
85 LegacyIRCompileLayer<decltype(ObjectLayer), SimpleCompiler> CompileLayer;
87 using OptimizeFunction =
88 std::function<std::unique_ptr<Module>(std::unique_ptr<Module>)>;
90 LegacyIRTransformLayer<decltype(CompileLayer), OptimizeFunction> OptimizeLayer;
92 JITCompileCallbackManager *CompileCallbackMgr;
93 std::unique_ptr<IndirectStubsManager> IndirectStubsMgr;
94 MyRemote &Remote;
96 public:
97 KaleidoscopeJIT(ExecutionSession &ES, MyRemote &Remote)
98 : ES(ES),
99 Resolver(createLegacyLookupResolver(
101 [this](const std::string &Name) -> JITSymbol {
102 if (auto Sym = IndirectStubsMgr->findStub(Name, false))
103 return Sym;
104 if (auto Sym = OptimizeLayer.findSymbol(Name, false))
105 return Sym;
106 else if (auto Err = Sym.takeError())
107 return std::move(Err);
108 if (auto Addr = cantFail(this->Remote.getSymbolAddress(Name)))
109 return JITSymbol(Addr, JITSymbolFlags::Exported);
110 return nullptr;
112 [](Error Err) { cantFail(std::move(Err), "lookupFlags failed"); })),
113 TM(EngineBuilder().selectTarget(Triple(Remote.getTargetTriple()), "",
114 "", SmallVector<std::string, 0>())),
115 DL(TM->createDataLayout()),
116 ObjectLayer(AcknowledgeORCv1Deprecation, ES,
117 [this](VModuleKey K) {
118 return LegacyRTDyldObjectLinkingLayer::Resources{
119 cantFail(this->Remote.createRemoteMemoryManager()),
120 Resolver};
122 CompileLayer(AcknowledgeORCv1Deprecation, ObjectLayer,
123 SimpleCompiler(*TM)),
124 OptimizeLayer(AcknowledgeORCv1Deprecation, CompileLayer,
125 [this](std::unique_ptr<Module> M) {
126 return optimizeModule(std::move(M));
128 Remote(Remote) {
129 auto CCMgrOrErr = Remote.enableCompileCallbacks(0);
130 if (!CCMgrOrErr) {
131 logAllUnhandledErrors(CCMgrOrErr.takeError(), errs(),
132 "Error enabling remote compile callbacks:");
133 exit(1);
135 CompileCallbackMgr = &*CCMgrOrErr;
136 IndirectStubsMgr = cantFail(Remote.createIndirectStubsManager());
137 llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
140 TargetMachine &getTargetMachine() { return *TM; }
142 VModuleKey addModule(std::unique_ptr<Module> M) {
143 // Add the module with a new VModuleKey.
144 auto K = ES.allocateVModule();
145 cantFail(OptimizeLayer.addModule(K, std::move(M)));
146 return K;
149 Error addFunctionAST(std::unique_ptr<FunctionAST> FnAST) {
150 // Move ownership of FnAST to a shared pointer - C++11 lambdas don't support
151 // capture-by-move, which is be required for unique_ptr.
152 auto SharedFnAST = std::shared_ptr<FunctionAST>(std::move(FnAST));
154 // Set the action to compile our AST. This lambda will be run if/when
155 // execution hits the compile callback (via the stub).
157 // The steps to compile are:
158 // (1) IRGen the function.
159 // (2) Add the IR module to the JIT to make it executable like any other
160 // module.
161 // (3) Use findSymbol to get the address of the compiled function.
162 // (4) Update the stub pointer to point at the implementation so that
163 /// subsequent calls go directly to it and bypass the compiler.
164 // (5) Return the address of the implementation: this lambda will actually
165 // be run inside an attempted call to the function, and we need to
166 // continue on to the implementation to complete the attempted call.
167 // The JIT runtime (the resolver block) will use the return address of
168 // this function as the address to continue at once it has reset the
169 // CPU state to what it was immediately before the call.
170 auto CompileAction = [this, SharedFnAST]() {
171 auto M = irgenAndTakeOwnership(*SharedFnAST, "$impl");
172 addModule(std::move(M));
173 auto Sym = findSymbol(SharedFnAST->getName() + "$impl");
174 assert(Sym && "Couldn't find compiled function?");
175 JITTargetAddress SymAddr = cantFail(Sym.getAddress());
176 if (auto Err = IndirectStubsMgr->updatePointer(
177 mangle(SharedFnAST->getName()), SymAddr)) {
178 logAllUnhandledErrors(std::move(Err), errs(),
179 "Error updating function pointer: ");
180 exit(1);
183 return SymAddr;
186 // Create a CompileCallback suing the CompileAction - this is the re-entry
187 // point into the compiler for functions that haven't been compiled yet.
188 auto CCAddr = cantFail(
189 CompileCallbackMgr->getCompileCallback(std::move(CompileAction)));
191 // Create an indirect stub. This serves as the functions "canonical
192 // definition" - an unchanging (constant address) entry point to the
193 // function implementation.
194 // Initially we point the stub's function-pointer at the compile callback
195 // that we just created. In the compile action for the callback we will
196 // update the stub's function pointer to point at the function
197 // implementation that we just implemented.
198 if (auto Err = IndirectStubsMgr->createStub(
199 mangle(SharedFnAST->getName()), CCAddr, JITSymbolFlags::Exported))
200 return Err;
202 return Error::success();
205 Error executeRemoteExpr(JITTargetAddress ExprAddr) {
206 return Remote.callVoidVoid(ExprAddr);
209 JITSymbol findSymbol(const std::string Name) {
210 return OptimizeLayer.findSymbol(mangle(Name), true);
213 void removeModule(VModuleKey K) {
214 cantFail(OptimizeLayer.removeModule(K));
217 private:
218 std::string mangle(const std::string &Name) {
219 std::string MangledName;
220 raw_string_ostream MangledNameStream(MangledName);
221 Mangler::getNameWithPrefix(MangledNameStream, Name, DL);
222 return MangledNameStream.str();
225 std::unique_ptr<Module> optimizeModule(std::unique_ptr<Module> M) {
226 // Create a function pass manager.
227 auto FPM = std::make_unique<legacy::FunctionPassManager>(M.get());
229 // Add some optimizations.
230 FPM->add(createInstructionCombiningPass());
231 FPM->add(createReassociatePass());
232 FPM->add(createGVNPass());
233 FPM->add(createCFGSimplificationPass());
234 FPM->doInitialization();
236 // Run the optimizations over all functions in the module being added to
237 // the JIT.
238 for (auto &F : *M)
239 FPM->run(F);
241 return M;
245 } // end namespace orc
246 } // end namespace llvm
248 #endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H