[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / lib / ExecutionEngine / ExecutionEngine.cpp
blobdbcc0ba6fc99c679471bb5abbb6522d1206fd823
1 //===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the execution engine for MLIR modules based on LLVM Orc
10 // JIT engine.
12 //===----------------------------------------------------------------------===//
13 #include "mlir/ExecutionEngine/ExecutionEngine.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Support/FileUtilities.h"
17 #include "mlir/Target/LLVMIR/Export.h"
19 #include "llvm/ExecutionEngine/JITEventListener.h"
20 #include "llvm/ExecutionEngine/ObjectCache.h"
21 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
22 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
23 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
24 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
25 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
26 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
27 #include "llvm/IR/IRBuilder.h"
28 #include "llvm/MC/TargetRegistry.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/Error.h"
31 #include "llvm/Support/ToolOutputFile.h"
32 #include "llvm/TargetParser/Host.h"
33 #include "llvm/TargetParser/SubtargetFeature.h"
35 #define DEBUG_TYPE "execution-engine"
37 using namespace mlir;
38 using llvm::dbgs;
39 using llvm::Error;
40 using llvm::errs;
41 using llvm::Expected;
42 using llvm::LLVMContext;
43 using llvm::MemoryBuffer;
44 using llvm::MemoryBufferRef;
45 using llvm::Module;
46 using llvm::SectionMemoryManager;
47 using llvm::StringError;
48 using llvm::Triple;
49 using llvm::orc::DynamicLibrarySearchGenerator;
50 using llvm::orc::ExecutionSession;
51 using llvm::orc::IRCompileLayer;
52 using llvm::orc::JITTargetMachineBuilder;
53 using llvm::orc::MangleAndInterner;
54 using llvm::orc::RTDyldObjectLinkingLayer;
55 using llvm::orc::SymbolMap;
56 using llvm::orc::ThreadSafeModule;
57 using llvm::orc::TMOwningSimpleCompiler;
59 /// Wrap a string into an llvm::StringError.
60 static Error makeStringError(const Twine &message) {
61 return llvm::make_error<StringError>(message.str(),
62 llvm::inconvertibleErrorCode());
65 void SimpleObjectCache::notifyObjectCompiled(const Module *m,
66 MemoryBufferRef objBuffer) {
67 cachedObjects[m->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
68 objBuffer.getBuffer(), objBuffer.getBufferIdentifier());
71 std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *m) {
72 auto i = cachedObjects.find(m->getModuleIdentifier());
73 if (i == cachedObjects.end()) {
74 LLVM_DEBUG(dbgs() << "No object for " << m->getModuleIdentifier()
75 << " in cache. Compiling.\n");
76 return nullptr;
78 LLVM_DEBUG(dbgs() << "Object for " << m->getModuleIdentifier()
79 << " loaded from cache.\n");
80 return MemoryBuffer::getMemBuffer(i->second->getMemBufferRef());
83 void SimpleObjectCache::dumpToObjectFile(StringRef outputFilename) {
84 // Set up the output file.
85 std::string errorMessage;
86 auto file = openOutputFile(outputFilename, &errorMessage);
87 if (!file) {
88 llvm::errs() << errorMessage << "\n";
89 return;
92 // Dump the object generated for a single module to the output file.
93 assert(cachedObjects.size() == 1 && "Expected only one object entry.");
94 auto &cachedObject = cachedObjects.begin()->second;
95 file->os() << cachedObject->getBuffer();
96 file->keep();
99 bool SimpleObjectCache::isEmpty() { return cachedObjects.empty(); }
101 void ExecutionEngine::dumpToObjectFile(StringRef filename) {
102 if (cache == nullptr) {
103 llvm::errs() << "cannot dump ExecutionEngine object code to file: "
104 "object cache is disabled\n";
105 return;
107 // Compilation is lazy and it doesn't populate object cache unless requested.
108 // In case object dump is requested before cache is populated, we need to
109 // force compilation manually.
110 if (cache->isEmpty()) {
111 for (std::string &functionName : functionNames) {
112 auto result = lookupPacked(functionName);
113 if (!result) {
114 llvm::errs() << "Could not compile " << functionName << ":\n "
115 << result.takeError() << "\n";
116 return;
120 cache->dumpToObjectFile(filename);
123 void ExecutionEngine::registerSymbols(
124 llvm::function_ref<SymbolMap(MangleAndInterner)> symbolMap) {
125 auto &mainJitDylib = jit->getMainJITDylib();
126 cantFail(mainJitDylib.define(
127 absoluteSymbols(symbolMap(llvm::orc::MangleAndInterner(
128 mainJitDylib.getExecutionSession(), jit->getDataLayout())))));
131 void ExecutionEngine::setupTargetTripleAndDataLayout(Module *llvmModule,
132 llvm::TargetMachine *tm) {
133 llvmModule->setDataLayout(tm->createDataLayout());
134 llvmModule->setTargetTriple(tm->getTargetTriple().getTriple());
137 static std::string makePackedFunctionName(StringRef name) {
138 return "_mlir_" + name.str();
141 // For each function in the LLVM module, define an interface function that wraps
142 // all the arguments of the original function and all its results into an i8**
143 // pointer to provide a unified invocation interface.
144 static void packFunctionArguments(Module *module) {
145 auto &ctx = module->getContext();
146 llvm::IRBuilder<> builder(ctx);
147 DenseSet<llvm::Function *> interfaceFunctions;
148 for (auto &func : module->getFunctionList()) {
149 if (func.isDeclaration()) {
150 continue;
152 if (interfaceFunctions.count(&func)) {
153 continue;
156 // Given a function `foo(<...>)`, define the interface function
157 // `mlir_foo(i8**)`.
158 auto *newType =
159 llvm::FunctionType::get(builder.getVoidTy(), builder.getPtrTy(),
160 /*isVarArg=*/false);
161 auto newName = makePackedFunctionName(func.getName());
162 auto funcCst = module->getOrInsertFunction(newName, newType);
163 llvm::Function *interfaceFunc = cast<llvm::Function>(funcCst.getCallee());
164 interfaceFunctions.insert(interfaceFunc);
166 // Extract the arguments from the type-erased argument list and cast them to
167 // the proper types.
168 auto *bb = llvm::BasicBlock::Create(ctx);
169 bb->insertInto(interfaceFunc);
170 builder.SetInsertPoint(bb);
171 llvm::Value *argList = interfaceFunc->arg_begin();
172 SmallVector<llvm::Value *, 8> args;
173 args.reserve(llvm::size(func.args()));
174 for (auto [index, arg] : llvm::enumerate(func.args())) {
175 llvm::Value *argIndex = llvm::Constant::getIntegerValue(
176 builder.getInt64Ty(), APInt(64, index));
177 llvm::Value *argPtrPtr =
178 builder.CreateGEP(builder.getInt8PtrTy(), argList, argIndex);
179 llvm::Value *argPtr =
180 builder.CreateLoad(builder.getInt8PtrTy(), argPtrPtr);
181 llvm::Type *argTy = arg.getType();
182 llvm::Value *load = builder.CreateLoad(argTy, argPtr);
183 args.push_back(load);
186 // Call the implementation function with the extracted arguments.
187 llvm::Value *result = builder.CreateCall(&func, args);
189 // Assuming the result is one value, potentially of type `void`.
190 if (!result->getType()->isVoidTy()) {
191 llvm::Value *retIndex = llvm::Constant::getIntegerValue(
192 builder.getInt64Ty(), APInt(64, llvm::size(func.args())));
193 llvm::Value *retPtrPtr =
194 builder.CreateGEP(builder.getInt8PtrTy(), argList, retIndex);
195 llvm::Value *retPtr =
196 builder.CreateLoad(builder.getInt8PtrTy(), retPtrPtr);
197 builder.CreateStore(result, retPtr);
200 // The interface function returns void.
201 builder.CreateRetVoid();
205 ExecutionEngine::ExecutionEngine(bool enableObjectDump,
206 bool enableGDBNotificationListener,
207 bool enablePerfNotificationListener)
208 : cache(enableObjectDump ? new SimpleObjectCache() : nullptr),
209 functionNames(),
210 gdbListener(enableGDBNotificationListener
211 ? llvm::JITEventListener::createGDBRegistrationListener()
212 : nullptr),
213 perfListener(nullptr) {
214 if (enablePerfNotificationListener) {
215 if (auto *listener = llvm::JITEventListener::createPerfJITEventListener())
216 perfListener = listener;
217 else if (auto *listener =
218 llvm::JITEventListener::createIntelJITEventListener())
219 perfListener = listener;
223 ExecutionEngine::~ExecutionEngine() {
224 // Run all dynamic library destroy callbacks to prepare for the shutdown.
225 for (LibraryDestroyFn destroy : destroyFns)
226 destroy();
229 Expected<std::unique_ptr<ExecutionEngine>>
230 ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
231 std::unique_ptr<llvm::TargetMachine> tm) {
232 auto engine = std::make_unique<ExecutionEngine>(
233 options.enableObjectDump, options.enableGDBNotificationListener,
234 options.enablePerfNotificationListener);
236 // Remember all entry-points if object dumping is enabled.
237 if (options.enableObjectDump) {
238 for (auto funcOp : m->getRegion(0).getOps<LLVM::LLVMFuncOp>()) {
239 StringRef funcName = funcOp.getSymName();
240 engine->functionNames.push_back(funcName.str());
244 std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
245 auto llvmModule = options.llvmModuleBuilder
246 ? options.llvmModuleBuilder(m, *ctx)
247 : translateModuleToLLVMIR(m, *ctx);
248 if (!llvmModule)
249 return makeStringError("could not convert to LLVM IR");
251 // If no valid TargetMachine was passed, create a default TM ignoring any
252 // input arguments from the user.
253 if (!tm) {
254 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
255 if (!tmBuilderOrError)
256 return tmBuilderOrError.takeError();
258 auto tmOrError = tmBuilderOrError->createTargetMachine();
259 if (!tmOrError)
260 return tmOrError.takeError();
261 tm = std::move(tmOrError.get());
264 // TODO: Currently, the LLVM module created above has no triple associated
265 // with it. Instead, the triple is extracted from the TargetMachine, which is
266 // either based on the host defaults or command line arguments when specified
267 // (set-up by callers of this method). It could also be passed to the
268 // translation or dialect conversion instead of this.
269 setupTargetTripleAndDataLayout(llvmModule.get(), tm.get());
270 packFunctionArguments(llvmModule.get());
272 auto dataLayout = llvmModule->getDataLayout();
274 // Use absolute library path so that gdb can find the symbol table.
275 SmallVector<SmallString<256>, 4> sharedLibPaths;
276 transform(
277 options.sharedLibPaths, std::back_inserter(sharedLibPaths),
278 [](StringRef libPath) {
279 SmallString<256> absPath(libPath.begin(), libPath.end());
280 cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
281 return absPath;
284 // If shared library implements custom execution layer library init and
285 // destroy functions, we'll use them to register the library. Otherwise, load
286 // the library as JITDyLib below.
287 llvm::StringMap<void *> exportSymbols;
288 SmallVector<LibraryDestroyFn> destroyFns;
289 SmallVector<StringRef> jitDyLibPaths;
291 for (auto &libPath : sharedLibPaths) {
292 auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(
293 libPath.str().str().c_str());
294 void *initSym = lib.getAddressOfSymbol(kLibraryInitFnName);
295 void *destroySim = lib.getAddressOfSymbol(kLibraryDestroyFnName);
297 // Library does not provide call backs, rely on symbol visiblity.
298 if (!initSym || !destroySim) {
299 jitDyLibPaths.push_back(libPath);
300 continue;
303 auto initFn = reinterpret_cast<LibraryInitFn>(initSym);
304 initFn(exportSymbols);
306 auto destroyFn = reinterpret_cast<LibraryDestroyFn>(destroySim);
307 destroyFns.push_back(destroyFn);
309 engine->destroyFns = std::move(destroyFns);
311 // Callback to create the object layer with symbol resolution to current
312 // process and dynamically linked libraries.
313 auto objectLinkingLayerCreator = [&](ExecutionSession &session,
314 const Triple &tt) {
315 auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
316 session, [sectionMemoryMapper = options.sectionMemoryMapper]() {
317 return std::make_unique<SectionMemoryManager>(sectionMemoryMapper);
320 // Register JIT event listeners if they are enabled.
321 if (engine->gdbListener)
322 objectLayer->registerJITEventListener(*engine->gdbListener);
323 if (engine->perfListener)
324 objectLayer->registerJITEventListener(*engine->perfListener);
326 // COFF format binaries (Windows) need special handling to deal with
327 // exported symbol visibility.
328 // cf llvm/lib/ExecutionEngine/Orc/LLJIT.cpp LLJIT::createObjectLinkingLayer
329 llvm::Triple targetTriple(llvm::Twine(llvmModule->getTargetTriple()));
330 if (targetTriple.isOSBinFormatCOFF()) {
331 objectLayer->setOverrideObjectFlagsWithResponsibilityFlags(true);
332 objectLayer->setAutoClaimResponsibilityForObjectSymbols(true);
335 // Resolve symbols from shared libraries.
336 for (auto &libPath : jitDyLibPaths) {
337 auto mb = llvm::MemoryBuffer::getFile(libPath);
338 if (!mb) {
339 errs() << "Failed to create MemoryBuffer for: " << libPath
340 << "\nError: " << mb.getError().message() << "\n";
341 continue;
343 auto &jd = session.createBareJITDylib(std::string(libPath));
344 auto loaded = DynamicLibrarySearchGenerator::Load(
345 libPath.str().c_str(), dataLayout.getGlobalPrefix());
346 if (!loaded) {
347 errs() << "Could not load " << libPath << ":\n " << loaded.takeError()
348 << "\n";
349 continue;
351 jd.addGenerator(std::move(*loaded));
352 cantFail(objectLayer->add(jd, std::move(mb.get())));
355 return objectLayer;
358 // Callback to inspect the cache and recompile on demand. This follows Lang's
359 // LLJITWithObjectCache example.
360 auto compileFunctionCreator = [&](JITTargetMachineBuilder jtmb)
361 -> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
362 if (options.jitCodeGenOptLevel)
363 jtmb.setCodeGenOptLevel(*options.jitCodeGenOptLevel);
364 return std::make_unique<TMOwningSimpleCompiler>(std::move(tm),
365 engine->cache.get());
368 // Create the LLJIT by calling the LLJITBuilder with 2 callbacks.
369 auto jit =
370 cantFail(llvm::orc::LLJITBuilder()
371 .setCompileFunctionCreator(compileFunctionCreator)
372 .setObjectLinkingLayerCreator(objectLinkingLayerCreator)
373 .setDataLayout(dataLayout)
374 .create());
376 // Add a ThreadSafemodule to the engine and return.
377 ThreadSafeModule tsm(std::move(llvmModule), std::move(ctx));
378 if (options.transformer)
379 cantFail(tsm.withModuleDo(
380 [&](llvm::Module &module) { return options.transformer(&module); }));
381 cantFail(jit->addIRModule(std::move(tsm)));
382 engine->jit = std::move(jit);
384 // Resolve symbols that are statically linked in the current process.
385 llvm::orc::JITDylib &mainJD = engine->jit->getMainJITDylib();
386 mainJD.addGenerator(
387 cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
388 dataLayout.getGlobalPrefix())));
390 // Build a runtime symbol map from the exported symbols and register them.
391 auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
392 auto symbolMap = llvm::orc::SymbolMap();
393 for (auto &exportSymbol : exportSymbols)
394 symbolMap[interner(exportSymbol.getKey())] = {
395 llvm::orc::ExecutorAddr::fromPtr(exportSymbol.getValue()),
396 llvm::JITSymbolFlags::Exported};
397 return symbolMap;
399 engine->registerSymbols(runtimeSymbolMap);
401 return std::move(engine);
404 Expected<void (*)(void **)>
405 ExecutionEngine::lookupPacked(StringRef name) const {
406 auto result = lookup(makePackedFunctionName(name));
407 if (!result)
408 return result.takeError();
409 return reinterpret_cast<void (*)(void **)>(result.get());
412 Expected<void *> ExecutionEngine::lookup(StringRef name) const {
413 auto expectedSymbol = jit->lookup(name);
415 // JIT lookup may return an Error referring to strings stored internally by
416 // the JIT. If the Error outlives the ExecutionEngine, it would want have a
417 // dangling reference, which is currently caught by an assertion inside JIT
418 // thanks to hand-rolled reference counting. Rewrap the error message into a
419 // string before returning. Alternatively, ORC JIT should consider copying
420 // the string into the error message.
421 if (!expectedSymbol) {
422 std::string errorMessage;
423 llvm::raw_string_ostream os(errorMessage);
424 llvm::handleAllErrors(expectedSymbol.takeError(),
425 [&os](llvm::ErrorInfoBase &ei) { ei.log(os); });
426 return makeStringError(os.str());
429 if (void *fptr = expectedSymbol->toPtr<void *>())
430 return fptr;
431 return makeStringError("looked up function is null");
434 Error ExecutionEngine::invokePacked(StringRef name,
435 MutableArrayRef<void *> args) {
436 auto expectedFPtr = lookupPacked(name);
437 if (!expectedFPtr)
438 return expectedFPtr.takeError();
439 auto fptr = *expectedFPtr;
441 (*fptr)(args.data());
443 return Error::success();