[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / unittests / IR / SymbolTableTest.cpp
blob5dcec749f0f425976acb16aa40b0390a748b92a1
1 //===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #include "mlir/IR/SymbolTable.h"
9 #include "mlir/IR/BuiltinOps.h"
10 #include "mlir/IR/Verifier.h"
11 #include "mlir/Interfaces/CallInterfaces.h"
12 #include "mlir/Interfaces/FunctionInterfaces.h"
13 #include "mlir/Parser/Parser.h"
15 #include "gtest/gtest.h"
17 using namespace mlir;
19 namespace test {
20 void registerTestDialect(DialectRegistry &);
21 } // namespace test
23 class ReplaceAllSymbolUsesTest : public ::testing::Test {
24 protected:
25 using ReplaceFnType = llvm::function_ref<LogicalResult(
26 SymbolTable, ModuleOp, Operation *, Operation *)>;
28 void SetUp() override {
29 ::test::registerTestDialect(registry);
30 context = std::make_unique<MLIRContext>(registry);
33 void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
34 // Set up IR and find func ops.
35 OwningOpRef<ModuleOp> module =
36 parseSourceString<ModuleOp>(kInput, context.get());
37 SymbolTable symbolTable(module.get());
38 auto opIterator = module->getBody(0)->getOperations().begin();
39 auto fooOp = cast<FunctionOpInterface>(opIterator++);
40 auto barOp = cast<FunctionOpInterface>(opIterator++);
41 ASSERT_EQ(fooOp.getNameAttr(), "foo");
42 ASSERT_EQ(barOp.getNameAttr(), "bar");
44 // Call test function that does symbol replacement.
45 LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp);
46 ASSERT_TRUE(succeeded(res));
47 ASSERT_TRUE(succeeded(verify(module.get())));
49 // Check that it got renamed.
50 bool calleeFound = false;
51 fooOp->walk([&](CallOpInterface callOp) {
52 StringAttr callee = callOp.getCallableForCallee()
53 .dyn_cast<SymbolRefAttr>()
54 .getLeafReference();
55 EXPECT_EQ(callee, "baz");
56 calleeFound = true;
57 });
58 EXPECT_TRUE(calleeFound);
61 std::unique_ptr<MLIRContext> context;
63 private:
64 constexpr static llvm::StringLiteral kInput = R"MLIR(
65 module {
66 test.conversion_func_op private @foo() {
67 "test.conversion_call_op"() { callee=@bar } : () -> ()
68 "test.return"() : () -> ()
70 test.conversion_func_op private @bar()
72 )MLIR";
74 DialectRegistry registry;
77 namespace {
79 TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
80 // Symbol as `Operation *`, rename within module.
81 testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
82 auto barOp) -> LogicalResult {
83 return symbolTable.replaceAllSymbolUses(
84 barOp, StringAttr::get(context.get(), "baz"), module);
85 });
88 TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
89 // Symbol as `StringAttr`, rename within module.
90 testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
91 auto barOp) -> LogicalResult {
92 return symbolTable.replaceAllSymbolUses(
93 StringAttr::get(context.get(), "bar"),
94 StringAttr::get(context.get(), "baz"), module);
95 });
98 TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
99 // Symbol as `Operation *`, rename within module body.
100 testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
101 auto barOp) -> LogicalResult {
102 return symbolTable.replaceAllSymbolUses(
103 barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
107 TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
108 // Symbol as `StringAttr`, rename within module body.
109 testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
110 auto barOp) -> LogicalResult {
111 return symbolTable.replaceAllSymbolUses(
112 StringAttr::get(context.get(), "bar"),
113 StringAttr::get(context.get(), "baz"), &module->getRegion(0));
117 TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
118 // Symbol as `Operation *`, rename within function.
119 testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
120 auto barOp) -> LogicalResult {
121 return symbolTable.replaceAllSymbolUses(
122 barOp, StringAttr::get(context.get(), "baz"), fooOp);
126 TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
127 // Symbol as `StringAttr`, rename within function.
128 testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
129 auto barOp) -> LogicalResult {
130 return symbolTable.replaceAllSymbolUses(
131 StringAttr::get(context.get(), "bar"),
132 StringAttr::get(context.get(), "baz"), fooOp);
136 } // namespace