1 //===- Invoke.cpp ------------------------------------*- C++ -*-===//
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
10 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
11 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
12 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
13 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
14 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/ExecutionEngine/CRunnerUtils.h"
18 #include "mlir/ExecutionEngine/ExecutionEngine.h"
19 #include "mlir/ExecutionEngine/MemRefUtils.h"
20 #include "mlir/ExecutionEngine/RunnerUtils.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/InitAllDialects.h"
23 #include "mlir/Parser/Parser.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
26 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
27 #include "mlir/Target/LLVMIR/Export.h"
28 #include "llvm/Support/TargetSelect.h"
29 #include "llvm/Support/raw_ostream.h"
31 #include "gmock/gmock.h"
33 // SPARC currently lacks JIT support.
// SKIP_WITHOUT_JIT(x) wraps a gtest test name. The first definition below
// prepends DISABLED_, which gtest treats as "compile but do not run"; the
// second leaves the name unchanged.
// NOTE(review): the preprocessor conditional choosing between these two
// definitions is not visible in this view; per the comment above it is
// presumably keyed on a SPARC target.
35 #define SKIP_WITHOUT_JIT(x) DISABLED_##x
37 #define SKIP_WITHOUT_JIT(x) x
42 // The JIT isn't supported on Windows at this time.
// File-scope static object used for its side effect: the visible body calls
// llvm::InitializeNativeTarget() and llvm::InitializeNativeTargetAsmPrinter(),
// presumably so the ExecutionEngine instances created by the tests below can
// emit native code.
// NOTE(review): the constructor header and the closing of this struct are not
// visible in this view. Confirm these calls run from its constructor during
// static initialization.
45 static struct LLVMInitializer
{
47 llvm::InitializeNativeTarget();
48 llvm::InitializeNativeTargetAsmPrinter();
52 /// Simple conversion pipeline for the purpose of testing sources written in
53 /// dialects lowering to LLVM Dialect.
54 static LogicalResult
lowerToLLVMDialect(ModuleOp module
) {
55 PassManager
pm(module
->getName());
56 pm
.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
57 pm
.addNestedPass
<func::FuncOp
>(mlir::createArithToLLVMConversionPass());
58 pm
.addPass(mlir::createConvertFuncToLLVMPass());
59 pm
.addPass(mlir::createReconcileUnrealizedCastsPass());
60 return pm
.run(module
);
// Round-trips an i32 through the JIT. @foo returns %arg0 + %arg0, is invoked
// with 42, and the result must equal 42 + 42.
// NOTE(review): two variants of moduleStr are visible; one annotates the
// argument and result with llvm.signext. The preprocessor conditional that
// selects between them, and the end of each raw string, are not visible in
// this view.
63 TEST(MLIRExecutionEngine
, SKIP_WITHOUT_JIT(AddInteger
)) {
65 std::string moduleStr
= R
"mlir(
66 func.func @foo(%arg0 : i32 {llvm.signext}) -> (i32 {llvm.signext}) attributes { llvm.emit_c_interface } {
67 %res = arith.addi %arg0, %arg0 : i32
72 std::string moduleStr
= R
"mlir(
73 func.func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
74 %res = arith.addi %arg0, %arg0 : i32
// Make every dialect available to the parser, and register the builtin and
// LLVM dialect translations. These are presumably consumed when
// ExecutionEngine::create exports the lowered module to LLVM IR.
79 DialectRegistry registry
;
80 registerAllDialects(registry
);
81 registerBuiltinDialectTranslation(registry
);
82 registerLLVMDialectTranslation(registry
);
83 MLIRContext
context(registry
);
84 OwningOpRef
<ModuleOp
> module
=
85 parseSourceString
<ModuleOp
>(moduleStr
, &context
);
86 ASSERT_TRUE(!!module
);
// Lower to the LLVM dialect, then JIT-compile; any failure aborts the test.
87 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module
)));
88 auto jitOrError
= ExecutionEngine::create(*module
);
89 ASSERT_TRUE(!!jitOrError
);
90 std::unique_ptr
<ExecutionEngine
> jit
= std::move(jitOrError
.get());
91 // The result of the function must be passed as output argument.
// NOTE(review): `result` is declared on a line not visible here. Whatever
// invoke() returns is presumably captured and checked on hidden lines.
94 jit
->invoke("foo", 42, ExecutionEngine::Result
<int>(result
));
96 ASSERT_EQ(result
, 42 + 42);
// JIT-compiles @foo, which computes %arg0 - %arg1 on f32. It is invoked with
// (43.0f, 1.0f), and the result must be 42.0f. The subtraction is exact, so
// comparing with ASSERT_EQ instead of a tolerance is safe.
99 TEST(MLIRExecutionEngine
, SKIP_WITHOUT_JIT(SubtractFloat
)) {
100 std::string moduleStr
= R
"mlir(
101 func.func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
102 %res = arith.subf %arg0, %arg1 : f32
// Register all dialects plus the builtin and LLVM dialect translations.
106 DialectRegistry registry
;
107 registerAllDialects(registry
);
108 registerBuiltinDialectTranslation(registry
);
109 registerLLVMDialectTranslation(registry
);
110 MLIRContext
context(registry
);
111 OwningOpRef
<ModuleOp
> module
=
112 parseSourceString
<ModuleOp
>(moduleStr
, &context
);
113 ASSERT_TRUE(!!module
);
// Lower to the LLVM dialect, then JIT-compile; any failure aborts the test.
114 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module
)));
115 auto jitOrError
= ExecutionEngine::create(*module
);
116 ASSERT_TRUE(!!jitOrError
);
117 std::unique_ptr
<ExecutionEngine
> jit
= std::move(jitOrError
.get());
118 // The result of the function must be passed as output argument.
// ExecutionEngine::result(result) presumably deduces the wrapper type that
// AddInteger spells explicitly as ExecutionEngine::Result<int>(result).
// NOTE(review): `result` is declared, and the llvm::Error from invoke() is
// presumably checked, on lines not visible here.
121 jit
->invoke("foo", 43.0f
, 1.0f
, ExecutionEngine::result(result
));
123 ASSERT_EQ(result
, 42.f
);
// Passes a rank-0 OwningMemRef<float, 0> to the JIT-compiled @zero_ranked,
// which stores 42.0 into it. The test then checks the element through
// host-side views.
126 TEST(NativeMemRefJit
, SKIP_WITHOUT_JIT(ZeroRankMemref
)) {
127 OwningMemRef
<float, 0> a({});
// NOTE(review): the statement that presumably sets the element to 42 before
// this check is not visible here.
129 ASSERT_EQ(*a
->data
, 42);
131 std::string moduleStr
= R
"mlir(
132 func.func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
133 %cst42 = arith.constant 42.0 : f32
134 memref.store %cst42, %arg0[] : memref<f32>
// Register all dialects plus the builtin and LLVM dialect translations.
138 DialectRegistry registry
;
139 registerAllDialects(registry
);
140 registerBuiltinDialectTranslation(registry
);
141 registerLLVMDialectTranslation(registry
);
142 MLIRContext
context(registry
);
143 auto module
= parseSourceString
<ModuleOp
>(moduleStr
, &context
);
144 ASSERT_TRUE(!!module
);
// Lower to the LLVM dialect, then JIT-compile; any failure aborts the test.
145 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module
)));
146 auto jitOrError
= ExecutionEngine::create(*module
);
147 ASSERT_TRUE(!!jitOrError
);
148 auto jit
= std::move(jitOrError
.get());
// NOTE(review): `error` is not visibly checked in this view. An llvm::Error
// destroyed unchecked aborts when LLVM's ABI-breaking checks are enabled.
// Confirm it is consumed on a hidden line.
150 llvm::Error error
= jit
->invoke("zero_ranked", &*a
);
// The JIT-side store must be visible through a[{}]. Every element yielded by
// iterating *a must be the one addressed by a[{}].
152 EXPECT_EQ((a
[{}]), 42.);
153 for (float &elt
: *a
)
154 EXPECT_EQ(&elt
, &(a
[{}]));
// Passes a 9-element rank-1 OwningMemRef to the JIT-compiled @one_ranked,
// which stores 42.0 at index 5. The test then checks the host-side contents.
157 TEST(NativeMemRefJit
, SKIP_WITHOUT_JIT(RankOneMemref
)) {
158 int64_t shape
[] = {9};
159 OwningMemRef
<float, 1> a(shape
);
// Check that range-for iteration over *a visits the same addresses as
// indexing a[{count - 1}].
// NOTE(review): the declaration and update of `count`, and the rest of this
// loop, are not visible here.
161 for (float &elt
: *a
) {
162 EXPECT_EQ(&elt
, &(a
[{count
- 1}]));
166 std::string moduleStr
= R
"mlir(
167 func.func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
168 %cst42 = arith.constant 42.0 : f32
169 %cst5 = arith.constant 5 : index
170 memref.store %cst42, %arg0[%cst5] : memref<?xf32>
// Register all dialects plus the builtin and LLVM dialect translations.
174 DialectRegistry registry
;
175 registerAllDialects(registry
);
176 registerBuiltinDialectTranslation(registry
);
177 registerLLVMDialectTranslation(registry
);
178 MLIRContext
context(registry
);
179 auto module
= parseSourceString
<ModuleOp
>(moduleStr
, &context
);
180 ASSERT_TRUE(!!module
);
// Lower to the LLVM dialect, then JIT-compile; any failure aborts the test.
181 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module
)));
182 auto jitOrError
= ExecutionEngine::create(*module
);
183 ASSERT_TRUE(!!jitOrError
);
184 auto jit
= std::move(jitOrError
.get());
// NOTE(review): `error` is not visibly checked in this view. Confirm it is
// consumed on a hidden line.
186 llvm::Error error
= jit
->invoke("one_ranked", &*a
);
// NOTE(review): only part of this loop body is visible. Presumably index 5
// (written with 42.0 by @one_ranked) is handled separately from the elements
// compared against `count` below.
189 for (float &elt
: *a
) {
193 EXPECT_EQ(elt
, count
);
// Exercises a padded 2-D OwningMemRef. The logical shape {k, m} sits inside a
// {k + 1, m + 1} allocation, and each element is initialized to m * i + j.
// The JIT-compiled @rank2_memref receives the same memref for both arguments.
// It stores 42.0 at [1, 2] through %arg0 and at [2, 1] through %arg1.
// NOTE(review): `k` and `m` are defined on lines not visible in this view.
198 TEST(NativeMemRefJit
, SKIP_WITHOUT_JIT(BasicMemref
)) {
201 // Prepare arguments beforehand.
// Element initializer, handed to OwningMemRef below (presumably invoked once
// per element with that element's two indices).
202 auto init
= [=](float &elt
, ArrayRef
<int64_t> indices
) {
203 assert(indices
.size() == 2);
204 elt
= m
* indices
[0] + indices
[1];
206 int64_t shape
[] = {k
, m
};
207 int64_t shapeAlloc
[] = {k
+ 1, m
+ 1};
208 OwningMemRef
<float, 2> a(shape
, shapeAlloc
, init
);
// Sizes report the logical shape. The outer stride reflects the padded
// allocation width (m + 1), and the inner stride is 1.
209 ASSERT_EQ(a
->sizes
[0], k
);
210 ASSERT_EQ(a
->sizes
[1], m
);
211 ASSERT_EQ(a
->strides
[0], m
+ 1);
212 ASSERT_EQ(a
->strides
[1], 1);
// Every element holds its initializer value, and a[{i, j}] and (*a)[i][j]
// address the same storage.
213 for (int i
= 0; i
< k
; ++i
) {
214 for (int j
= 0; j
< m
; ++j
) {
215 EXPECT_EQ((a
[{i
, j
}]), i
* m
+ j
);
216 EXPECT_EQ(&(a
[{i
, j
}]), &((*a
)[i
][j
]));
219 std::string moduleStr
= R
"mlir(
220 func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
221 %x = arith.constant 2 : index
222 %y = arith.constant 1 : index
223 %cst42 = arith.constant 42.0 : f32
224 memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
225 memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
// Register all dialects plus the builtin and LLVM dialect translations.
229 DialectRegistry registry
;
230 registerAllDialects(registry
);
231 registerBuiltinDialectTranslation(registry
);
232 registerLLVMDialectTranslation(registry
);
233 MLIRContext
context(registry
);
234 OwningOpRef
<ModuleOp
> module
=
235 parseSourceString
<ModuleOp
>(moduleStr
, &context
);
236 ASSERT_TRUE(!!module
);
// Lower to the LLVM dialect, then JIT-compile; any failure aborts the test.
237 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module
)));
238 auto jitOrError
= ExecutionEngine::create(*module
);
239 ASSERT_TRUE(!!jitOrError
);
240 std::unique_ptr
<ExecutionEngine
> jit
= std::move(jitOrError
.get());
// The same memref is passed as both arguments, so both stores land in `a`.
// NOTE(review): `error` is not visibly checked in this view. Confirm it is
// consumed on a hidden line.
242 llvm::Error error
= jit
->invoke("rank2_memref", &*a
, &*a
);
244 EXPECT_EQ(((*a
)[1][2]), 42.);
245 EXPECT_EQ((a
[{2, 1}]), 42.);
248 // A helper function that will be called from the JIT
// The JITCallback test registers this function with the JIT under the symbol
// `_mlir_ciface_callback`. It iterates over every element of the 2-D strided
// memref.
// NOTE(review): the loop body is not visible in this view. Presumably each
// element is scaled by `coefficient`, since the test later expects
// elt == coefficient * count.
249 static void memrefMultiply(::StridedMemRefType
<float, 2> *memref
,
250 int32_t coefficient
) {
251 for (float &elt
: *memref
)
255 // MSAN does not work with JIT.
// MAYBE_JITCallback is the test name used by the JITCallback test. Under
// MemorySanitizer it is forced to DISABLED_JITCallback. Otherwise it goes
// through SKIP_WITHOUT_JIT like the other tests in this file.
// NOTE(review): the #else/#endif lines of this conditional are not visible in
// this view.
256 #if __has_feature(memory_sanitizer)
257 #define MAYBE_JITCallback DISABLED_JITCallback
259 #define MAYBE_JITCallback SKIP_WITHOUT_JIT(JITCallback)
261 TEST(NativeMemRefJit
, MAYBE_JITCallback
) {
264 int64_t shape
[] = {k
, m
};
265 int64_t shapeAlloc
[] = {k
+ 1, m
+ 1};
266 OwningMemRef
<float, 2> a(shape
, shapeAlloc
);
268 for (float &elt
: *a
)
272 std::string moduleStr
= R
"mlir(
273 func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32 {llvm.signext}) attributes { llvm.emit_c_interface }
274 func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32 {llvm.signext}) attributes { llvm.emit_c_interface } {
275 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
276 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
281 std::string moduleStr
= R
"mlir(
282 func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface }
283 func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
284 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
285 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
291 DialectRegistry registry
;
292 registerAllDialects(registry
);
293 registerBuiltinDialectTranslation(registry
);
294 registerLLVMDialectTranslation(registry
);
295 MLIRContext
context(registry
);
296 auto module
= parseSourceString
<ModuleOp
>(moduleStr
, &context
);
297 ASSERT_TRUE(!!module
);
298 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module
)));
299 auto jitOrError
= ExecutionEngine::create(*module
);
300 ASSERT_TRUE(!!jitOrError
);
301 auto jit
= std::move(jitOrError
.get());
302 // Define any extra symbols so they're available at runtime.
303 jit
->registerSymbols([&](llvm::orc::MangleAndInterner interner
) {
304 llvm::orc::SymbolMap symbolMap
;
305 symbolMap
[interner("_mlir_ciface_callback")] = {
306 llvm::orc::ExecutorAddr::fromPtr(memrefMultiply
),
307 llvm::JITSymbolFlags::Exported
};
311 int32_t coefficient
= 3.;
312 llvm::Error error
= jit
->invoke("caller_for_callback", &*a
, coefficient
);
316 ASSERT_EQ(elt
, coefficient
* count
++);