1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
13 // The translation can be customized by providing an MLIR to MLIR
15 //===----------------------------------------------------------------------===//
17 #include "mlir/ExecutionEngine/JitRunner.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/ExecutionEngine/ExecutionEngine.h"
21 #include "mlir/ExecutionEngine/OptUtils.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/Parser/Parser.h"
25 #include "mlir/Support/FileUtilities.h"
26 #include "mlir/Tools/ParseUtilities.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassNameParser.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/FileUtilities.h"
37 #include "llvm/Support/SourceMgr.h"
38 #include "llvm/Support/StringSaver.h"
39 #include "llvm/Support/ToolOutputFile.h"
45 #define DEBUG_TYPE "jit-runner"
51 /// This options struct prevents the need for global static initializers, and
52 /// is only initialized if the JITRunner is invoked.
54 llvm::cl::opt
<std::string
> inputFilename
{llvm::cl::Positional
,
55 llvm::cl::desc("<input file>"),
57 llvm::cl::opt
<std::string
> mainFuncName
{
58 "e", llvm::cl::desc("The function to be called"),
59 llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
60 llvm::cl::opt
<std::string
> mainFuncType
{
62 llvm::cl::desc("Textual description of the function type to be called"),
63 llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
65 llvm::cl::OptionCategory optFlags
{"opt-like flags"};
67 // CLI variables for -On options.
68 llvm::cl::opt
<bool> optO0
{"O0",
69 llvm::cl::desc("Run opt passes and codegen at O0"),
70 llvm::cl::cat(optFlags
)};
71 llvm::cl::opt
<bool> optO1
{"O1",
72 llvm::cl::desc("Run opt passes and codegen at O1"),
73 llvm::cl::cat(optFlags
)};
74 llvm::cl::opt
<bool> optO2
{"O2",
75 llvm::cl::desc("Run opt passes and codegen at O2"),
76 llvm::cl::cat(optFlags
)};
77 llvm::cl::opt
<bool> optO3
{"O3",
78 llvm::cl::desc("Run opt passes and codegen at O3"),
79 llvm::cl::cat(optFlags
)};
81 llvm::cl::list
<std::string
> mAttrs
{
82 "mattr", llvm::cl::MiscFlags::CommaSeparated
,
83 llvm::cl::desc("Target specific attributes (-mattr=help for details)"),
84 llvm::cl::value_desc("a1,+a2,-a3,..."), llvm::cl::cat(optFlags
)};
86 llvm::cl::opt
<std::string
> mArch
{
88 llvm::cl::desc("Architecture to generate code for (see --version)")};
90 llvm::cl::OptionCategory clOptionsCategory
{"linking options"};
91 llvm::cl::list
<std::string
> clSharedLibs
{
92 "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
93 llvm::cl::MiscFlags::CommaSeparated
, llvm::cl::cat(clOptionsCategory
)};
95 /// CLI variables for debugging.
96 llvm::cl::opt
<bool> dumpObjectFile
{
98 llvm::cl::desc("Dump JITted-compiled object to file specified with "
99 "-object-filename (<input file>.o by default).")};
101 llvm::cl::opt
<std::string
> objectFilename
{
103 llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
105 llvm::cl::opt
<bool> hostSupportsJit
{"host-supports-jit",
106 llvm::cl::desc("Report host JIT support"),
109 llvm::cl::opt
<bool> noImplicitModule
{
110 "no-implicit-module",
112 "Disable implicit addition of a top-level module op during parsing"),
113 llvm::cl::init(false)};
116 struct CompileAndExecuteConfig
{
117 /// LLVM module transformer that is passed to ExecutionEngine.
118 std::function
<llvm::Error(llvm::Module
*)> transformer
;
120 /// A custom function that is passed to ExecutionEngine. It processes MLIR
121 /// module and creates LLVM IR module.
122 llvm::function_ref
<std::unique_ptr
<llvm::Module
>(Operation
*,
123 llvm::LLVMContext
&)>
126 /// A custom function that is passed to ExecutinEngine to register symbols at
128 llvm::function_ref
<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner
)>
134 static OwningOpRef
<Operation
*> parseMLIRInput(StringRef inputFilename
,
135 bool insertImplicitModule
,
136 MLIRContext
*context
) {
137 // Set up the input file.
138 std::string errorMessage
;
139 auto file
= openInputFile(inputFilename
, &errorMessage
);
141 llvm::errs() << errorMessage
<< "\n";
145 auto sourceMgr
= std::make_shared
<llvm::SourceMgr
>();
146 sourceMgr
->AddNewSourceBuffer(std::move(file
), SMLoc());
147 OwningOpRef
<Operation
*> module
=
148 parseSourceFileForTool(sourceMgr
, context
, insertImplicitModule
);
151 if (!module
.get()->hasTrait
<OpTrait::SymbolTable
>()) {
152 llvm::errs() << "Error: top-level op must be a symbol table.\n";
158 static inline Error
makeStringError(const Twine
&message
) {
159 return llvm::make_error
<llvm::StringError
>(message
.str(),
160 llvm::inconvertibleErrorCode());
163 static std::optional
<unsigned> getCommandLineOptLevel(Options
&options
) {
164 std::optional
<unsigned> optLevel
;
165 SmallVector
<std::reference_wrapper
<llvm::cl::opt
<bool>>, 4> optFlags
{
166 options
.optO0
, options
.optO1
, options
.optO2
, options
.optO3
};
168 // Determine if there is an optimization flag present.
169 for (unsigned j
= 0; j
< 4; ++j
) {
170 auto &flag
= optFlags
[j
].get();
179 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
181 compileAndExecute(Options
&options
, Operation
*module
, StringRef entryPoint
,
182 CompileAndExecuteConfig config
, void **args
,
183 std::unique_ptr
<llvm::TargetMachine
> tm
= nullptr) {
184 std::optional
<llvm::CodeGenOptLevel
> jitCodeGenOptLevel
;
185 if (auto clOptLevel
= getCommandLineOptLevel(options
))
186 jitCodeGenOptLevel
= static_cast<llvm::CodeGenOptLevel
>(*clOptLevel
);
188 SmallVector
<StringRef
, 4> sharedLibs(options
.clSharedLibs
.begin(),
189 options
.clSharedLibs
.end());
191 mlir::ExecutionEngineOptions engineOptions
;
192 engineOptions
.llvmModuleBuilder
= config
.llvmModuleBuilder
;
193 if (config
.transformer
)
194 engineOptions
.transformer
= config
.transformer
;
195 engineOptions
.jitCodeGenOptLevel
= jitCodeGenOptLevel
;
196 engineOptions
.sharedLibPaths
= sharedLibs
;
197 engineOptions
.enableObjectDump
= true;
198 auto expectedEngine
=
199 mlir::ExecutionEngine::create(module
, engineOptions
, std::move(tm
));
201 return expectedEngine
.takeError();
203 auto engine
= std::move(*expectedEngine
);
205 auto expectedFPtr
= engine
->lookupPacked(entryPoint
);
207 return expectedFPtr
.takeError();
209 if (options
.dumpObjectFile
)
210 engine
->dumpToObjectFile(options
.objectFilename
.empty()
211 ? options
.inputFilename
+ ".o"
212 : options
.objectFilename
);
214 void (*fptr
)(void **) = *expectedFPtr
;
217 return Error::success();
220 static Error
compileAndExecuteVoidFunction(
221 Options
&options
, Operation
*module
, StringRef entryPoint
,
222 CompileAndExecuteConfig config
, std::unique_ptr
<llvm::TargetMachine
> tm
) {
223 auto mainFunction
= dyn_cast_or_null
<LLVM::LLVMFuncOp
>(
224 SymbolTable::lookupSymbolIn(module
, entryPoint
));
225 if (!mainFunction
|| mainFunction
.empty())
226 return makeStringError("entry point not found");
228 auto resultType
= dyn_cast
<LLVM::LLVMVoidType
>(
229 mainFunction
.getFunctionType().getReturnType());
231 return makeStringError("expected void function");
233 void *empty
= nullptr;
234 return compileAndExecute(options
, module
, entryPoint
, std::move(config
),
235 &empty
, std::move(tm
));
238 template <typename Type
>
239 Error
checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction
);
241 Error checkCompatibleReturnType
<int32_t>(LLVM::LLVMFuncOp mainFunction
) {
242 auto resultType
= dyn_cast
<IntegerType
>(
243 cast
<LLVM::LLVMFunctionType
>(mainFunction
.getFunctionType())
245 if (!resultType
|| resultType
.getWidth() != 32)
246 return makeStringError("only single i32 function result supported");
247 return Error::success();
250 Error checkCompatibleReturnType
<int64_t>(LLVM::LLVMFuncOp mainFunction
) {
251 auto resultType
= dyn_cast
<IntegerType
>(
252 cast
<LLVM::LLVMFunctionType
>(mainFunction
.getFunctionType())
254 if (!resultType
|| resultType
.getWidth() != 64)
255 return makeStringError("only single i64 function result supported");
256 return Error::success();
259 Error checkCompatibleReturnType
<float>(LLVM::LLVMFuncOp mainFunction
) {
260 if (!isa
<Float32Type
>(
261 cast
<LLVM::LLVMFunctionType
>(mainFunction
.getFunctionType())
263 return makeStringError("only single f32 function result supported");
264 return Error::success();
266 template <typename Type
>
267 Error
compileAndExecuteSingleReturnFunction(
268 Options
&options
, Operation
*module
, StringRef entryPoint
,
269 CompileAndExecuteConfig config
, std::unique_ptr
<llvm::TargetMachine
> tm
) {
270 auto mainFunction
= dyn_cast_or_null
<LLVM::LLVMFuncOp
>(
271 SymbolTable::lookupSymbolIn(module
, entryPoint
));
272 if (!mainFunction
|| mainFunction
.isExternal())
273 return makeStringError("entry point not found");
275 if (cast
<LLVM::LLVMFunctionType
>(mainFunction
.getFunctionType())
276 .getNumParams() != 0)
277 return makeStringError("function inputs not supported");
279 if (Error error
= checkCompatibleReturnType
<Type
>(mainFunction
))
288 compileAndExecute(options
, module
, entryPoint
, std::move(config
),
289 (void **)&data
, std::move(tm
)))
292 // Intentional printing of the output so we can test.
293 llvm::outs() << res
<< '\n';
295 return Error::success();
298 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
299 /// standard C++ main functions.
300 int mlir::JitRunnerMain(int argc
, char **argv
, const DialectRegistry
®istry
,
301 JitRunnerConfig config
) {
302 llvm::ExitOnError exitOnErr
;
304 // Create the options struct containing the command line options for the
305 // runner. This must come before the command line options are parsed.
307 llvm::cl::ParseCommandLineOptions(argc
, argv
, "MLIR CPU execution driver\n");
309 if (options
.hostSupportsJit
) {
310 auto j
= llvm::orc::LLJITBuilder().create();
312 llvm::outs() << "true\n";
314 llvm::outs() << "false\n";
315 exitOnErr(j
.takeError());
320 std::optional
<unsigned> optLevel
= getCommandLineOptLevel(options
);
321 SmallVector
<std::reference_wrapper
<llvm::cl::opt
<bool>>, 4> optFlags
{
322 options
.optO0
, options
.optO1
, options
.optO2
, options
.optO3
};
324 MLIRContext
context(registry
);
326 auto m
= parseMLIRInput(options
.inputFilename
, !options
.noImplicitModule
,
329 llvm::errs() << "could not parse the input IR\n";
333 JitRunnerOptions runnerOptions
{options
.mainFuncName
, options
.mainFuncType
};
334 if (config
.mlirTransformer
)
335 if (failed(config
.mlirTransformer(m
.get(), runnerOptions
)))
338 auto tmBuilderOrError
= llvm::orc::JITTargetMachineBuilder::detectHost();
339 if (!tmBuilderOrError
) {
340 llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
344 // Configure TargetMachine builder based on the command line options
345 llvm::SubtargetFeatures features
;
346 if (!options
.mAttrs
.empty()) {
347 for (StringRef attr
: options
.mAttrs
)
348 features
.AddFeature(attr
);
349 tmBuilderOrError
->addFeatures(features
.getFeatures());
352 if (!options
.mArch
.empty()) {
353 tmBuilderOrError
->getTargetTriple().setArchName(options
.mArch
);
356 // Build TargetMachine
357 auto tmOrError
= tmBuilderOrError
->createTargetMachine();
360 llvm::errs() << "Failed to create a TargetMachine for the host\n";
361 exitOnErr(tmOrError
.takeError());
365 llvm::dbgs() << " JITTargetMachineBuilder is "
366 << llvm::orc::JITTargetMachineBuilderPrinter(*tmBuilderOrError
,
370 CompileAndExecuteConfig compileAndExecuteConfig
;
372 compileAndExecuteConfig
.transformer
= mlir::makeOptimizingTransformer(
373 *optLevel
, /*sizeLevel=*/0, /*targetMachine=*/tmOrError
->get());
375 compileAndExecuteConfig
.llvmModuleBuilder
= config
.llvmModuleBuilder
;
376 compileAndExecuteConfig
.runtimeSymbolMap
= config
.runtimesymbolMap
;
378 // Get the function used to compile and execute the module.
379 using CompileAndExecuteFnT
=
380 Error (*)(Options
&, Operation
*, StringRef
, CompileAndExecuteConfig
,
381 std::unique_ptr
<llvm::TargetMachine
> tm
);
382 auto compileAndExecuteFn
=
383 StringSwitch
<CompileAndExecuteFnT
>(options
.mainFuncType
.getValue())
384 .Case("i32", compileAndExecuteSingleReturnFunction
<int32_t>)
385 .Case("i64", compileAndExecuteSingleReturnFunction
<int64_t>)
386 .Case("f32", compileAndExecuteSingleReturnFunction
<float>)
387 .Case("void", compileAndExecuteVoidFunction
)
390 Error error
= compileAndExecuteFn
391 ? compileAndExecuteFn(
392 options
, m
.get(), options
.mainFuncName
.getValue(),
393 compileAndExecuteConfig
, std::move(tmOrError
.get()))
394 : makeStringError("unsupported function type");
396 int exitCode
= EXIT_SUCCESS
;
397 llvm::handleAllErrors(std::move(error
),
398 [&exitCode
](const llvm::ErrorInfoBase
&info
) {
399 llvm::errs() << "Error: ";
400 info
.log(llvm::errs());
401 llvm::errs() << '\n';
402 exitCode
= EXIT_FAILURE
;