//===- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains a simple JIT definition for use in the Kaleidoscope tutorials.
// This variant compiles function bodies lazily (via indirect stubs and compile
// callbacks) and executes the generated code on a remote target.
//
//===----------------------------------------------------------------------===//
13 #ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
14 #define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
#include "RemoteJITUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Triple.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
#include "llvm/ExecutionEngine/Orc/IndirectionUtils.h"
#include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
#include "llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Mangler.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include <cassert>
#include <cstdlib>
#include <functional>
#include <memory>
#include <string>
#include <utility>
50 /// FunctionAST - This class represents a function definition itself.
52 std::unique_ptr
<PrototypeAST
> Proto
;
53 std::unique_ptr
<ExprAST
> Body
;
56 FunctionAST(std::unique_ptr
<PrototypeAST
> Proto
,
57 std::unique_ptr
<ExprAST
> Body
)
58 : Proto(std::move(Proto
)), Body(std::move(Body
)) {}
60 const PrototypeAST
& getProto() const;
61 const std::string
& getName() const;
62 llvm::Function
*codegen();
65 /// This will compile FnAST to IR, rename the function to add the given
66 /// suffix (needed to prevent a name-clash with the function's stub),
67 /// and then take ownership of the module that the function was compiled
69 std::unique_ptr
<llvm::Module
>
70 irgenAndTakeOwnership(FunctionAST
&FnAST
, const std::string
&Suffix
);
75 // Typedef the remote-client API.
76 using MyRemote
= remote::OrcRemoteTargetClient
;
78 class KaleidoscopeJIT
{
81 std::shared_ptr
<SymbolResolver
> Resolver
;
82 std::unique_ptr
<TargetMachine
> TM
;
84 LegacyRTDyldObjectLinkingLayer ObjectLayer
;
85 LegacyIRCompileLayer
<decltype(ObjectLayer
), SimpleCompiler
> CompileLayer
;
87 using OptimizeFunction
=
88 std::function
<std::unique_ptr
<Module
>(std::unique_ptr
<Module
>)>;
90 LegacyIRTransformLayer
<decltype(CompileLayer
), OptimizeFunction
> OptimizeLayer
;
92 JITCompileCallbackManager
*CompileCallbackMgr
;
93 std::unique_ptr
<IndirectStubsManager
> IndirectStubsMgr
;
97 KaleidoscopeJIT(ExecutionSession
&ES
, MyRemote
&Remote
)
99 Resolver(createLegacyLookupResolver(
101 [this](const std::string
&Name
) -> JITSymbol
{
102 if (auto Sym
= IndirectStubsMgr
->findStub(Name
, false))
104 if (auto Sym
= OptimizeLayer
.findSymbol(Name
, false))
106 else if (auto Err
= Sym
.takeError())
107 return std::move(Err
);
108 if (auto Addr
= cantFail(this->Remote
.getSymbolAddress(Name
)))
109 return JITSymbol(Addr
, JITSymbolFlags::Exported
);
112 [](Error Err
) { cantFail(std::move(Err
), "lookupFlags failed"); })),
113 TM(EngineBuilder().selectTarget(Triple(Remote
.getTargetTriple()), "",
114 "", SmallVector
<std::string
, 0>())),
115 DL(TM
->createDataLayout()),
116 ObjectLayer(AcknowledgeORCv1Deprecation
, ES
,
117 [this](VModuleKey K
) {
118 return LegacyRTDyldObjectLinkingLayer::Resources
{
119 cantFail(this->Remote
.createRemoteMemoryManager()),
122 CompileLayer(AcknowledgeORCv1Deprecation
, ObjectLayer
,
123 SimpleCompiler(*TM
)),
124 OptimizeLayer(AcknowledgeORCv1Deprecation
, CompileLayer
,
125 [this](std::unique_ptr
<Module
> M
) {
126 return optimizeModule(std::move(M
));
129 auto CCMgrOrErr
= Remote
.enableCompileCallbacks(0);
131 logAllUnhandledErrors(CCMgrOrErr
.takeError(), errs(),
132 "Error enabling remote compile callbacks:");
135 CompileCallbackMgr
= &*CCMgrOrErr
;
136 IndirectStubsMgr
= cantFail(Remote
.createIndirectStubsManager());
137 llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
140 TargetMachine
&getTargetMachine() { return *TM
; }
142 VModuleKey
addModule(std::unique_ptr
<Module
> M
) {
143 // Add the module with a new VModuleKey.
144 auto K
= ES
.allocateVModule();
145 cantFail(OptimizeLayer
.addModule(K
, std::move(M
)));
149 Error
addFunctionAST(std::unique_ptr
<FunctionAST
> FnAST
) {
150 // Move ownership of FnAST to a shared pointer - C++11 lambdas don't support
151 // capture-by-move, which is be required for unique_ptr.
152 auto SharedFnAST
= std::shared_ptr
<FunctionAST
>(std::move(FnAST
));
154 // Set the action to compile our AST. This lambda will be run if/when
155 // execution hits the compile callback (via the stub).
157 // The steps to compile are:
158 // (1) IRGen the function.
159 // (2) Add the IR module to the JIT to make it executable like any other
161 // (3) Use findSymbol to get the address of the compiled function.
162 // (4) Update the stub pointer to point at the implementation so that
163 /// subsequent calls go directly to it and bypass the compiler.
164 // (5) Return the address of the implementation: this lambda will actually
165 // be run inside an attempted call to the function, and we need to
166 // continue on to the implementation to complete the attempted call.
167 // The JIT runtime (the resolver block) will use the return address of
168 // this function as the address to continue at once it has reset the
169 // CPU state to what it was immediately before the call.
170 auto CompileAction
= [this, SharedFnAST
]() {
171 auto M
= irgenAndTakeOwnership(*SharedFnAST
, "$impl");
172 addModule(std::move(M
));
173 auto Sym
= findSymbol(SharedFnAST
->getName() + "$impl");
174 assert(Sym
&& "Couldn't find compiled function?");
175 JITTargetAddress SymAddr
= cantFail(Sym
.getAddress());
176 if (auto Err
= IndirectStubsMgr
->updatePointer(
177 mangle(SharedFnAST
->getName()), SymAddr
)) {
178 logAllUnhandledErrors(std::move(Err
), errs(),
179 "Error updating function pointer: ");
186 // Create a CompileCallback suing the CompileAction - this is the re-entry
187 // point into the compiler for functions that haven't been compiled yet.
188 auto CCAddr
= cantFail(
189 CompileCallbackMgr
->getCompileCallback(std::move(CompileAction
)));
191 // Create an indirect stub. This serves as the functions "canonical
192 // definition" - an unchanging (constant address) entry point to the
193 // function implementation.
194 // Initially we point the stub's function-pointer at the compile callback
195 // that we just created. In the compile action for the callback we will
196 // update the stub's function pointer to point at the function
197 // implementation that we just implemented.
198 if (auto Err
= IndirectStubsMgr
->createStub(
199 mangle(SharedFnAST
->getName()), CCAddr
, JITSymbolFlags::Exported
))
202 return Error::success();
205 Error
executeRemoteExpr(JITTargetAddress ExprAddr
) {
206 return Remote
.callVoidVoid(ExprAddr
);
209 JITSymbol
findSymbol(const std::string Name
) {
210 return OptimizeLayer
.findSymbol(mangle(Name
), true);
213 void removeModule(VModuleKey K
) {
214 cantFail(OptimizeLayer
.removeModule(K
));
218 std::string
mangle(const std::string
&Name
) {
219 std::string MangledName
;
220 raw_string_ostream
MangledNameStream(MangledName
);
221 Mangler::getNameWithPrefix(MangledNameStream
, Name
, DL
);
222 return MangledNameStream
.str();
225 std::unique_ptr
<Module
> optimizeModule(std::unique_ptr
<Module
> M
) {
226 // Create a function pass manager.
227 auto FPM
= std::make_unique
<legacy::FunctionPassManager
>(M
.get());
229 // Add some optimizations.
230 FPM
->add(createInstructionCombiningPass());
231 FPM
->add(createReassociatePass());
232 FPM
->add(createGVNPass());
233 FPM
->add(createCFGSimplificationPass());
234 FPM
->doInitialization();
236 // Run the optimizations over all functions in the module being added to
245 } // end namespace orc
246 } // end namespace llvm
248 #endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H