1 //===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file implements the execution engine for MLIR modules based on LLVM Orc
// JIT engine.
12 //===----------------------------------------------------------------------===//
13 #include "mlir/ExecutionEngine/ExecutionEngine.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Support/FileUtilities.h"
17 #include "mlir/Target/LLVMIR/Export.h"
19 #include "llvm/ExecutionEngine/JITEventListener.h"
20 #include "llvm/ExecutionEngine/ObjectCache.h"
21 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
22 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
23 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
24 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
25 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
26 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
27 #include "llvm/IR/IRBuilder.h"
28 #include "llvm/MC/TargetRegistry.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/Error.h"
31 #include "llvm/Support/ToolOutputFile.h"
32 #include "llvm/TargetParser/Host.h"
33 #include "llvm/TargetParser/SubtargetFeature.h"
35 #define DEBUG_TYPE "execution-engine"
42 using llvm::LLVMContext
;
43 using llvm::MemoryBuffer
;
44 using llvm::MemoryBufferRef
;
46 using llvm::SectionMemoryManager
;
47 using llvm::StringError
;
49 using llvm::orc::DynamicLibrarySearchGenerator
;
50 using llvm::orc::ExecutionSession
;
51 using llvm::orc::IRCompileLayer
;
52 using llvm::orc::JITTargetMachineBuilder
;
53 using llvm::orc::MangleAndInterner
;
54 using llvm::orc::RTDyldObjectLinkingLayer
;
55 using llvm::orc::SymbolMap
;
56 using llvm::orc::ThreadSafeModule
;
57 using llvm::orc::TMOwningSimpleCompiler
;
59 /// Wrap a string into an llvm::StringError.
60 static Error
makeStringError(const Twine
&message
) {
61 return llvm::make_error
<StringError
>(message
.str(),
62 llvm::inconvertibleErrorCode());
65 void SimpleObjectCache::notifyObjectCompiled(const Module
*m
,
66 MemoryBufferRef objBuffer
) {
67 cachedObjects
[m
->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
68 objBuffer
.getBuffer(), objBuffer
.getBufferIdentifier());
71 std::unique_ptr
<MemoryBuffer
> SimpleObjectCache::getObject(const Module
*m
) {
72 auto i
= cachedObjects
.find(m
->getModuleIdentifier());
73 if (i
== cachedObjects
.end()) {
74 LLVM_DEBUG(dbgs() << "No object for " << m
->getModuleIdentifier()
75 << " in cache. Compiling.\n");
78 LLVM_DEBUG(dbgs() << "Object for " << m
->getModuleIdentifier()
79 << " loaded from cache.\n");
80 return MemoryBuffer::getMemBuffer(i
->second
->getMemBufferRef());
83 void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename
) {
84 // Set up the output file.
85 std::string errorMessage
;
86 auto file
= openOutputFile(outputFilename
, &errorMessage
);
88 llvm::errs() << errorMessage
<< "\n";
92 // Dump the object generated for a single module to the output file.
93 assert(cachedObjects
.size() == 1 && "Expected only one object entry.");
94 auto &cachedObject
= cachedObjects
.begin()->second
;
95 file
->os() << cachedObject
->getBuffer();
99 bool SimpleObjectCache::isEmpty() { return cachedObjects
.empty(); }
101 void ExecutionEngine::dumpToObjectFile(StringRef filename
) {
102 if (cache
== nullptr) {
103 llvm::errs() << "cannot dump ExecutionEngine object code to file: "
104 "object cache is disabled\n";
107 // Compilation is lazy and it doesn't populate object cache unless requested.
108 // In case object dump is requested before cache is populated, we need to
109 // force compilation manually.
110 if (cache
->isEmpty()) {
111 for (std::string
&functionName
: functionNames
) {
112 auto result
= lookupPacked(functionName
);
114 llvm::errs() << "Could not compile " << functionName
<< ":\n "
115 << result
.takeError() << "\n";
120 cache
->dumpToObjectFile(filename
);
123 void ExecutionEngine::registerSymbols(
124 llvm::function_ref
<SymbolMap(MangleAndInterner
)> symbolMap
) {
125 auto &mainJitDylib
= jit
->getMainJITDylib();
126 cantFail(mainJitDylib
.define(
127 absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner(
128 mainJitDylib
.getExecutionSession(), jit
->getDataLayout())))));
131 void ExecutionEngine::setupTargetTripleAndDataLayout(Module
*llvmModule
,
132 llvm::TargetMachine
*tm
) {
133 llvmModule
->setDataLayout(tm
->createDataLayout());
134 llvmModule
->setTargetTriple(tm
->getTargetTriple().getTriple());
137 static std::string
makePackedFunctionName(StringRef name
) {
138 return "_mlir_" + name
.str();
141 // For each function in the LLVM module, define an interface function that wraps
142 // all the arguments of the original function and all its results into an i8**
143 // pointer to provide a unified invocation interface.
144 static void packFunctionArguments(Module
*module
) {
145 auto &ctx
= module
->getContext();
146 llvm::IRBuilder
<> builder(ctx
);
147 DenseSet
<llvm::Function
*> interfaceFunctions
;
148 for (auto &func
: module
->getFunctionList()) {
149 if (func
.isDeclaration()) {
152 if (interfaceFunctions
.count(&func
)) {
156 // Given a function `foo(<...>)`, define the interface function
159 llvm::FunctionType::get(builder
.getVoidTy(), builder
.getPtrTy(),
161 auto newName
= makePackedFunctionName(func
.getName());
162 auto funcCst
= module
->getOrInsertFunction(newName
, newType
);
163 llvm::Function
*interfaceFunc
= cast
<llvm::Function
>(funcCst
.getCallee());
164 interfaceFunctions
.insert(interfaceFunc
);
166 // Extract the arguments from the type-erased argument list and cast them to
168 auto *bb
= llvm::BasicBlock::Create(ctx
);
169 bb
->insertInto(interfaceFunc
);
170 builder
.SetInsertPoint(bb
);
171 llvm::Value
*argList
= interfaceFunc
->arg_begin();
172 SmallVector
<llvm::Value
*, 8> args
;
173 args
.reserve(llvm::size(func
.args()));
174 for (auto [index
, arg
] : llvm::enumerate(func
.args())) {
175 llvm::Value
*argIndex
= llvm::Constant::getIntegerValue(
176 builder
.getInt64Ty(), APInt(64, index
));
177 llvm::Value
*argPtrPtr
=
178 builder
.CreateGEP(builder
.getPtrTy(), argList
, argIndex
);
179 llvm::Value
*argPtr
= builder
.CreateLoad(builder
.getPtrTy(), argPtrPtr
);
180 llvm::Type
*argTy
= arg
.getType();
181 llvm::Value
*load
= builder
.CreateLoad(argTy
, argPtr
);
182 args
.push_back(load
);
185 // Call the implementation function with the extracted arguments.
186 llvm::Value
*result
= builder
.CreateCall(&func
, args
);
188 // Assuming the result is one value, potentially of type `void`.
189 if (!result
->getType()->isVoidTy()) {
190 llvm::Value
*retIndex
= llvm::Constant::getIntegerValue(
191 builder
.getInt64Ty(), APInt(64, llvm::size(func
.args())));
192 llvm::Value
*retPtrPtr
=
193 builder
.CreateGEP(builder
.getPtrTy(), argList
, retIndex
);
194 llvm::Value
*retPtr
= builder
.CreateLoad(builder
.getPtrTy(), retPtrPtr
);
195 builder
.CreateStore(result
, retPtr
);
198 // The interface function returns void.
199 builder
.CreateRetVoid();
203 ExecutionEngine::ExecutionEngine(bool enableObjectDump
,
204 bool enableGDBNotificationListener
,
205 bool enablePerfNotificationListener
)
206 : cache(enableObjectDump
? new SimpleObjectCache() : nullptr),
208 gdbListener(enableGDBNotificationListener
209 ? llvm::JITEventListener::createGDBRegistrationListener()
211 perfListener(nullptr) {
212 if (enablePerfNotificationListener
) {
213 if (auto *listener
= llvm::JITEventListener::createPerfJITEventListener())
214 perfListener
= listener
;
215 else if (auto *listener
=
216 llvm::JITEventListener::createIntelJITEventListener())
217 perfListener
= listener
;
221 ExecutionEngine::~ExecutionEngine() {
222 // Execute the global destructors from the module being processed.
223 // TODO: Allow JIT deinitialize for AArch64. Currently there's a bug causing a
224 // crash for AArch64 see related issue #71963.
225 if (jit
&& !jit
->getTargetTriple().isAArch64())
226 llvm::consumeError(jit
->deinitialize(jit
->getMainJITDylib()));
227 // Run all dynamic library destroy callbacks to prepare for the shutdown.
228 for (LibraryDestroyFn destroy
: destroyFns
)
232 Expected
<std::unique_ptr
<ExecutionEngine
>>
233 ExecutionEngine::create(Operation
*m
, const ExecutionEngineOptions
&options
,
234 std::unique_ptr
<llvm::TargetMachine
> tm
) {
235 auto engine
= std::make_unique
<ExecutionEngine
>(
236 options
.enableObjectDump
, options
.enableGDBNotificationListener
,
237 options
.enablePerfNotificationListener
);
239 // Remember all entry-points if object dumping is enabled.
240 if (options
.enableObjectDump
) {
241 for (auto funcOp
: m
->getRegion(0).getOps
<LLVM::LLVMFuncOp
>()) {
242 StringRef funcName
= funcOp
.getSymName();
243 engine
->functionNames
.push_back(funcName
.str());
247 std::unique_ptr
<llvm::LLVMContext
> ctx(new llvm::LLVMContext
);
248 auto llvmModule
= options
.llvmModuleBuilder
249 ? options
.llvmModuleBuilder(m
, *ctx
)
250 : translateModuleToLLVMIR(m
, *ctx
);
252 return makeStringError("could not convert to LLVM IR");
254 // If no valid TargetMachine was passed, create a default TM ignoring any
255 // input arguments from the user.
257 auto tmBuilderOrError
= llvm::orc::JITTargetMachineBuilder::detectHost();
258 if (!tmBuilderOrError
)
259 return tmBuilderOrError
.takeError();
261 auto tmOrError
= tmBuilderOrError
->createTargetMachine();
263 return tmOrError
.takeError();
264 tm
= std::move(tmOrError
.get());
267 // TODO: Currently, the LLVM module created above has no triple associated
268 // with it. Instead, the triple is extracted from the TargetMachine, which is
269 // either based on the host defaults or command line arguments when specified
270 // (set-up by callers of this method). It could also be passed to the
271 // translation or dialect conversion instead of this.
272 setupTargetTripleAndDataLayout(llvmModule
.get(), tm
.get());
273 packFunctionArguments(llvmModule
.get());
275 auto dataLayout
= llvmModule
->getDataLayout();
277 // Use absolute library path so that gdb can find the symbol table.
278 SmallVector
<SmallString
<256>, 4> sharedLibPaths
;
280 options
.sharedLibPaths
, std::back_inserter(sharedLibPaths
),
281 [](StringRef libPath
) {
282 SmallString
<256> absPath(libPath
.begin(), libPath
.end());
283 cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath
)));
287 // If shared library implements custom execution layer library init and
288 // destroy functions, we'll use them to register the library. Otherwise, load
289 // the library as JITDyLib below.
290 llvm::StringMap
<void *> exportSymbols
;
291 SmallVector
<LibraryDestroyFn
> destroyFns
;
292 SmallVector
<StringRef
> jitDyLibPaths
;
294 for (auto &libPath
: sharedLibPaths
) {
295 auto lib
= llvm::sys::DynamicLibrary::getPermanentLibrary(
296 libPath
.str().str().c_str());
297 void *initSym
= lib
.getAddressOfSymbol(kLibraryInitFnName
);
298 void *destroySim
= lib
.getAddressOfSymbol(kLibraryDestroyFnName
);
300 // Library does not provide call backs, rely on symbol visiblity.
301 if (!initSym
|| !destroySim
) {
302 jitDyLibPaths
.push_back(libPath
);
306 auto initFn
= reinterpret_cast<LibraryInitFn
>(initSym
);
307 initFn(exportSymbols
);
309 auto destroyFn
= reinterpret_cast<LibraryDestroyFn
>(destroySim
);
310 destroyFns
.push_back(destroyFn
);
312 engine
->destroyFns
= std::move(destroyFns
);
314 // Callback to create the object layer with symbol resolution to current
315 // process and dynamically linked libraries.
316 auto objectLinkingLayerCreator
= [&](ExecutionSession
&session
,
318 auto objectLayer
= std::make_unique
<RTDyldObjectLinkingLayer
>(
319 session
, [sectionMemoryMapper
= options
.sectionMemoryMapper
]() {
320 return std::make_unique
<SectionMemoryManager
>(sectionMemoryMapper
);
323 // Register JIT event listeners if they are enabled.
324 if (engine
->gdbListener
)
325 objectLayer
->registerJITEventListener(*engine
->gdbListener
);
326 if (engine
->perfListener
)
327 objectLayer
->registerJITEventListener(*engine
->perfListener
);
329 // COFF format binaries (Windows) need special handling to deal with
330 // exported symbol visibility.
331 // cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp LLJIT::createObjectLinkingLayer
332 llvm::Triple
targetTriple(llvm::Twine(llvmModule
->getTargetTriple()));
333 if (targetTriple
.isOSBinFormatCOFF()) {
334 objectLayer
->setOverrideObjectFlagsWithResponsibilityFlags(true);
335 objectLayer
->setAutoClaimResponsibilityForObjectSymbols(true);
338 // Resolve symbols from shared libraries.
339 for (auto &libPath
: jitDyLibPaths
) {
340 auto mb
= llvm::MemoryBuffer::getFile(libPath
);
342 errs() << "Failed to create MemoryBuffer for: " << libPath
343 << "\nError: " << mb
.getError().message() << "\n";
346 auto &jd
= session
.createBareJITDylib(std::string(libPath
));
347 auto loaded
= DynamicLibrarySearchGenerator::Load(
348 libPath
.str().c_str(), dataLayout
.getGlobalPrefix());
350 errs() << "Could not load " << libPath
<< ":\n " << loaded
.takeError()
354 jd
.addGenerator(std::move(*loaded
));
355 cantFail(objectLayer
->add(jd
, std::move(mb
.get())));
361 // Callback to inspect the cache and recompile on demand. This follows Lang's
362 // LLJITWithObjectCache example.
363 auto compileFunctionCreator
= [&](JITTargetMachineBuilder jtmb
)
364 -> Expected
<std::unique_ptr
<IRCompileLayer::IRCompiler
>> {
365 if (options
.jitCodeGenOptLevel
)
366 jtmb
.setCodeGenOptLevel(*options
.jitCodeGenOptLevel
);
367 return std::make_unique
<TMOwningSimpleCompiler
>(std::move(tm
),
368 engine
->cache
.get());
371 // Create the LLJIT by calling the LLJITBuilder with 2 callbacks.
373 cantFail(llvm::orc::LLJITBuilder()
374 .setCompileFunctionCreator(compileFunctionCreator
)
375 .setObjectLinkingLayerCreator(objectLinkingLayerCreator
)
376 .setDataLayout(dataLayout
)
379 // Add a ThreadSafemodule to the engine and return.
380 ThreadSafeModule
tsm(std::move(llvmModule
), std::move(ctx
));
381 if (options
.transformer
)
382 cantFail(tsm
.withModuleDo(
383 [&](llvm::Module
&module
) { return options
.transformer(&module
); }));
384 cantFail(jit
->addIRModule(std::move(tsm
)));
385 engine
->jit
= std::move(jit
);
387 // Resolve symbols that are statically linked in the current process.
388 llvm::orc::JITDylib
&mainJD
= engine
->jit
->getMainJITDylib();
390 cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
391 dataLayout
.getGlobalPrefix())));
393 // Build a runtime symbol map from the exported symbols and register them.
394 auto runtimeSymbolMap
= [&](llvm::orc::MangleAndInterner interner
) {
395 auto symbolMap
= llvm::orc::SymbolMap();
396 for (auto &exportSymbol
: exportSymbols
)
397 symbolMap
[interner(exportSymbol
.getKey())] = {
398 llvm::orc::ExecutorAddr::fromPtr(exportSymbol
.getValue()),
399 llvm::JITSymbolFlags::Exported
};
402 engine
->registerSymbols(runtimeSymbolMap
);
404 // Execute the global constructors from the module being processed.
405 // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
406 // crash for AArch64 see related issue #71963.
407 if (!engine
->jit
->getTargetTriple().isAArch64())
408 cantFail(engine
->jit
->initialize(engine
->jit
->getMainJITDylib()));
410 return std::move(engine
);
413 Expected
<void (*)(void **)>
414 ExecutionEngine::lookupPacked(StringRef name
) const {
415 auto result
= lookup(makePackedFunctionName(name
));
417 return result
.takeError();
418 return reinterpret_cast<void (*)(void **)>(result
.get());
421 Expected
<void *> ExecutionEngine::lookup(StringRef name
) const {
422 auto expectedSymbol
= jit
->lookup(name
);
424 // JIT lookup may return an Error referring to strings stored internally by
425 // the JIT. If the Error outlives the ExecutionEngine, it would want have a
426 // dangling reference, which is currently caught by an assertion inside JIT
427 // thanks to hand-rolled reference counting. Rewrap the error message into a
428 // string before returning. Alternatively, ORC JIT should consider copying
429 // the string into the error message.
430 if (!expectedSymbol
) {
431 std::string errorMessage
;
432 llvm::raw_string_ostream
os(errorMessage
);
433 llvm::handleAllErrors(expectedSymbol
.takeError(),
434 [&os
](llvm::ErrorInfoBase
&ei
) { ei
.log(os
); });
435 return makeStringError(os
.str());
438 if (void *fptr
= expectedSymbol
->toPtr
<void *>())
440 return makeStringError("looked up function is null");
443 Error
ExecutionEngine::invokePacked(StringRef name
,
444 MutableArrayRef
<void *> args
) {
445 auto expectedFPtr
= lookupPacked(name
);
447 return expectedFPtr
.takeError();
448 auto fptr
= *expectedFPtr
;
450 (*fptr
)(args
.data());
452 return Error::success();