[LLD][COFF] Emit tail merge pdata for delay load thunks on ARM64EC (#116810)
[llvm-project.git] / mlir / unittests / IR / AttrTypeReplacerTest.cpp
blobc7b42eb267c7ade31e29770b9739da8e7cf9f9d1
//===- AttrTypeReplacerTest.cpp - Sub-element replacer unit tests ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AttrTypeSubElements.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "gtest/gtest.h"

#include <optional>

using namespace mlir;
17 //===----------------------------------------------------------------------===//
18 // CyclicAttrTypeReplacer
19 //===----------------------------------------------------------------------===//
21 TEST(CyclicAttrTypeReplacerTest, testNoRecursion) {
22 MLIRContext ctx;
24 CyclicAttrTypeReplacer replacer;
25 replacer.addReplacement([&](BoolAttr b) {
26 return StringAttr::get(&ctx, b.getValue() ? "true" : "false");
27 });
29 EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, true)),
30 StringAttr::get(&ctx, "true"));
31 EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, false)),
32 StringAttr::get(&ctx, "false"));
33 EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)),
34 mlir::UnitAttr::get(&ctx));
37 TEST(CyclicAttrTypeReplacerTest, testInPlaceRecursionPruneAnywhere) {
38 MLIRContext ctx;
39 Builder b(&ctx);
41 CyclicAttrTypeReplacer replacer;
42 // Replacer cycles through integer attrs 0 -> 1 -> 2 -> 0 -> ...
43 replacer.addReplacement([&](IntegerAttr attr) {
44 return replacer.replace(b.getI8IntegerAttr((attr.getInt() + 1) % 3));
45 });
46 // The first repeat of any integer attr is pruned into a unit attr.
47 replacer.addCycleBreaker([&](IntegerAttr attr) { return b.getUnitAttr(); });
49 // No recursion case.
50 EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)),
51 mlir::UnitAttr::get(&ctx));
52 // Starting at 0.
53 EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(0)), mlir::UnitAttr::get(&ctx));
54 // Starting at 2.
55 EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(2)), mlir::UnitAttr::get(&ctx));
58 //===----------------------------------------------------------------------===//
59 // CyclicAttrTypeReplacerTest: ChainRecursion
60 //===----------------------------------------------------------------------===//
62 class CyclicAttrTypeReplacerChainRecursionPruningTest : public ::testing::Test {
63 public:
64 CyclicAttrTypeReplacerChainRecursionPruningTest() : b(&ctx) {
65 // IntegerType<width = N>
66 // ==> FunctionType<() => IntegerType< width = (N+1) % 3>>.
67 // This will create a chain of infinite length without recursion pruning.
68 replacer.addReplacement([&](mlir::IntegerType intType) {
69 ++invokeCount;
70 return b.getFunctionType(
71 {}, {mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3)});
72 });
75 void setBaseCase(std::optional<unsigned> pruneAt) {
76 replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) {
77 return (!pruneAt || intType.getWidth() == *pruneAt)
78 ? std::make_optional(b.getIndexType())
79 : std::nullopt;
80 });
83 Type getFunctionTypeChain(unsigned N) {
84 Type type = b.getIndexType();
85 for (unsigned i = 0; i < N; i++)
86 type = b.getFunctionType({}, type);
87 return type;
90 MLIRContext ctx;
91 Builder b;
92 CyclicAttrTypeReplacer replacer;
93 int invokeCount = 0;
96 TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere0) {
97 setBaseCase(std::nullopt);
99 // No recursion case.
100 EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType());
101 EXPECT_EQ(invokeCount, 0);
103 // Starting at 0. Cycle length is 3.
104 invokeCount = 0;
105 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
106 getFunctionTypeChain(3));
107 EXPECT_EQ(invokeCount, 3);
109 // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
110 invokeCount = 0;
111 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
112 getFunctionTypeChain(5));
113 EXPECT_EQ(invokeCount, 2);
116 TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere1) {
117 setBaseCase(std::nullopt);
119 // Starting at 1. Cycle length is 3.
120 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
121 getFunctionTypeChain(3));
122 EXPECT_EQ(invokeCount, 3);
125 TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific0) {
126 setBaseCase(0);
128 // Starting at 0. Cycle length is 3.
129 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
130 getFunctionTypeChain(3));
131 EXPECT_EQ(invokeCount, 3);
134 TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific1) {
135 setBaseCase(0);
137 // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
138 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
139 getFunctionTypeChain(5));
140 EXPECT_EQ(invokeCount, 5);
143 //===----------------------------------------------------------------------===//
144 // CyclicAttrTypeReplacerTest: BranchingRecusion
145 //===----------------------------------------------------------------------===//
147 class CyclicAttrTypeReplacerBranchingRecusionPruningTest
148 : public ::testing::Test {
149 public:
150 CyclicAttrTypeReplacerBranchingRecusionPruningTest() : b(&ctx) {
151 // IntegerType<width = N>
152 // ==> FunctionType<
153 // IntegerType< width = (N+1) % 3> =>
154 // IntegerType< width = (N+1) % 3>>.
155 // This will create a binary tree of infinite depth without pruning.
156 replacer.addReplacement([&](mlir::IntegerType intType) {
157 ++invokeCount;
158 Type child = mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3);
159 return b.getFunctionType({child}, {child});
163 void setBaseCase(std::optional<unsigned> pruneAt) {
164 replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) {
165 return (!pruneAt || intType.getWidth() == *pruneAt)
166 ? std::make_optional(b.getIndexType())
167 : std::nullopt;
171 Type getFunctionTypeTree(unsigned N) {
172 Type type = b.getIndexType();
173 for (unsigned i = 0; i < N; i++)
174 type = b.getFunctionType(type, type);
175 return type;
178 MLIRContext ctx;
179 Builder b;
180 CyclicAttrTypeReplacer replacer;
181 int invokeCount = 0;
184 TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere0) {
185 setBaseCase(std::nullopt);
187 // No recursion case.
188 EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType());
189 EXPECT_EQ(invokeCount, 0);
191 // Starting at 0. Cycle length is 3.
192 invokeCount = 0;
193 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
194 getFunctionTypeTree(3));
195 // Since both branches are identical, this should incur linear invocations
196 // of the replacement function instead of exponential.
197 EXPECT_EQ(invokeCount, 3);
199 // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
200 invokeCount = 0;
201 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
202 getFunctionTypeTree(5));
203 EXPECT_EQ(invokeCount, 2);
206 TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere1) {
207 setBaseCase(std::nullopt);
209 // Starting at 1. Cycle length is 3.
210 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
211 getFunctionTypeTree(3));
212 EXPECT_EQ(invokeCount, 3);
215 TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific0) {
216 setBaseCase(0);
218 // Starting at 0. Cycle length is 3.
219 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
220 getFunctionTypeTree(3));
221 EXPECT_EQ(invokeCount, 3);
224 TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific1) {
225 setBaseCase(0);
227 // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
228 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
229 getFunctionTypeTree(5));
230 EXPECT_EQ(invokeCount, 5);