1 //===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
8 #include "mlir/IR/SymbolTable.h"
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/IR/Verifier.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 #include "mlir/Interfaces/FunctionInterfaces.h"
13 #include "mlir/Parser/Parser.h"
15 #include "gtest/gtest.h"
20 void registerTestDialect(DialectRegistry
&);
23 class ReplaceAllSymbolUsesTest
: public ::testing::Test
{
25 using ReplaceFnType
= llvm::function_ref
<LogicalResult(
26 SymbolTable
, ModuleOp
, Operation
*, Operation
*)>;
28 void SetUp() override
{
29 ::test::registerTestDialect(registry
);
30 context
= std::make_unique
<MLIRContext
>(registry
);
33 void testReplaceAllSymbolUses(ReplaceFnType replaceFn
) {
34 // Set up IR and find func ops.
35 OwningOpRef
<ModuleOp
> module
=
36 parseSourceString
<ModuleOp
>(kInput
, context
.get());
37 SymbolTable
symbolTable(module
.get());
38 auto opIterator
= module
->getBody(0)->getOperations().begin();
39 auto fooOp
= cast
<FunctionOpInterface
>(opIterator
++);
40 auto barOp
= cast
<FunctionOpInterface
>(opIterator
++);
41 ASSERT_EQ(fooOp
.getNameAttr(), "foo");
42 ASSERT_EQ(barOp
.getNameAttr(), "bar");
44 // Call test function that does symbol replacement.
45 LogicalResult res
= replaceFn(symbolTable
, module
.get(), fooOp
, barOp
);
46 ASSERT_TRUE(succeeded(res
));
47 ASSERT_TRUE(succeeded(verify(module
.get())));
49 // Check that it got renamed.
50 bool calleeFound
= false;
51 fooOp
->walk([&](CallOpInterface callOp
) {
52 StringAttr callee
= callOp
.getCallableForCallee()
53 .dyn_cast
<SymbolRefAttr
>()
55 EXPECT_EQ(callee
, "baz");
58 EXPECT_TRUE(calleeFound
);
61 std::unique_ptr
<MLIRContext
> context
;
64 constexpr static llvm::StringLiteral kInput
= R
"MLIR(
66 test.conversion_func_op private @foo() {
67 "test
.conversion_call_op
"() { callee=@bar } : () -> ()
68 "test
.return"() : () -> ()
70 test.conversion_func_op private @bar()
74 DialectRegistry registry
;
79 TEST_F(ReplaceAllSymbolUsesTest
, OperationInModuleOp
) {
80 // Symbol as `Operation *`, rename within module.
81 testReplaceAllSymbolUses([&](auto symbolTable
, auto module
, auto fooOp
,
82 auto barOp
) -> LogicalResult
{
83 return symbolTable
.replaceAllSymbolUses(
84 barOp
, StringAttr::get(context
.get(), "baz"), module
);
88 TEST_F(ReplaceAllSymbolUsesTest
, StringAttrInModuleOp
) {
89 // Symbol as `StringAttr`, rename within module.
90 testReplaceAllSymbolUses([&](auto symbolTable
, auto module
, auto fooOp
,
91 auto barOp
) -> LogicalResult
{
92 return symbolTable
.replaceAllSymbolUses(
93 StringAttr::get(context
.get(), "bar"),
94 StringAttr::get(context
.get(), "baz"), module
);
98 TEST_F(ReplaceAllSymbolUsesTest
, OperationInModuleBody
) {
99 // Symbol as `Operation *`, rename within module body.
100 testReplaceAllSymbolUses([&](auto symbolTable
, auto module
, auto fooOp
,
101 auto barOp
) -> LogicalResult
{
102 return symbolTable
.replaceAllSymbolUses(
103 barOp
, StringAttr::get(context
.get(), "baz"), &module
->getRegion(0));
107 TEST_F(ReplaceAllSymbolUsesTest
, StringAttrInModuleBody
) {
108 // Symbol as `StringAttr`, rename within module body.
109 testReplaceAllSymbolUses([&](auto symbolTable
, auto module
, auto fooOp
,
110 auto barOp
) -> LogicalResult
{
111 return symbolTable
.replaceAllSymbolUses(
112 StringAttr::get(context
.get(), "bar"),
113 StringAttr::get(context
.get(), "baz"), &module
->getRegion(0));
117 TEST_F(ReplaceAllSymbolUsesTest
, OperationInFuncOp
) {
118 // Symbol as `Operation *`, rename within function.
119 testReplaceAllSymbolUses([&](auto symbolTable
, auto module
, auto fooOp
,
120 auto barOp
) -> LogicalResult
{
121 return symbolTable
.replaceAllSymbolUses(
122 barOp
, StringAttr::get(context
.get(), "baz"), fooOp
);
126 TEST_F(ReplaceAllSymbolUsesTest
, StringAttrInFuncOp
) {
127 // Symbol as `StringAttr`, rename within function.
128 testReplaceAllSymbolUses([&](auto symbolTable
, auto module
, auto fooOp
,
129 auto barOp
) -> LogicalResult
{
130 return symbolTable
.replaceAllSymbolUses(
131 StringAttr::get(context
.get(), "bar"),
132 StringAttr::get(context
.get(), "baz"), fooOp
);