//===- AttrTypeReplacerTest.cpp - Sub-element replacer unit tests ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AttrTypeSubElements.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "gtest/gtest.h"

#include <optional>

using namespace mlir;
//===----------------------------------------------------------------------===//
// CyclicAttrTypeReplacer
//===----------------------------------------------------------------------===//
21 TEST(CyclicAttrTypeReplacerTest
, testNoRecursion
) {
24 CyclicAttrTypeReplacer replacer
;
25 replacer
.addReplacement([&](BoolAttr b
) {
26 return StringAttr::get(&ctx
, b
.getValue() ? "true" : "false");
29 EXPECT_EQ(replacer
.replace(BoolAttr::get(&ctx
, true)),
30 StringAttr::get(&ctx
, "true"));
31 EXPECT_EQ(replacer
.replace(BoolAttr::get(&ctx
, false)),
32 StringAttr::get(&ctx
, "false"));
33 EXPECT_EQ(replacer
.replace(mlir::UnitAttr::get(&ctx
)),
34 mlir::UnitAttr::get(&ctx
));
37 TEST(CyclicAttrTypeReplacerTest
, testInPlaceRecursionPruneAnywhere
) {
41 CyclicAttrTypeReplacer replacer
;
42 // Replacer cycles through integer attrs 0 -> 1 -> 2 -> 0 -> ...
43 replacer
.addReplacement([&](IntegerAttr attr
) {
44 return replacer
.replace(b
.getI8IntegerAttr((attr
.getInt() + 1) % 3));
46 // The first repeat of any integer attr is pruned into a unit attr.
47 replacer
.addCycleBreaker([&](IntegerAttr attr
) { return b
.getUnitAttr(); });
50 EXPECT_EQ(replacer
.replace(mlir::UnitAttr::get(&ctx
)),
51 mlir::UnitAttr::get(&ctx
));
53 EXPECT_EQ(replacer
.replace(b
.getI8IntegerAttr(0)), mlir::UnitAttr::get(&ctx
));
55 EXPECT_EQ(replacer
.replace(b
.getI8IntegerAttr(2)), mlir::UnitAttr::get(&ctx
));
//===----------------------------------------------------------------------===//
// CyclicAttrTypeReplacerTest: ChainRecursion
//===----------------------------------------------------------------------===//
62 class CyclicAttrTypeReplacerChainRecursionPruningTest
: public ::testing::Test
{
64 CyclicAttrTypeReplacerChainRecursionPruningTest() : b(&ctx
) {
65 // IntegerType<width = N>
66 // ==> FunctionType<() => IntegerType< width = (N+1) % 3>>.
67 // This will create a chain of infinite length without recursion pruning.
68 replacer
.addReplacement([&](mlir::IntegerType intType
) {
70 return b
.getFunctionType(
71 {}, {mlir::IntegerType::get(&ctx
, (intType
.getWidth() + 1) % 3)});
75 void setBaseCase(std::optional
<unsigned> pruneAt
) {
76 replacer
.addCycleBreaker([&, pruneAt
](mlir::IntegerType intType
) {
77 return (!pruneAt
|| intType
.getWidth() == *pruneAt
)
78 ? std::make_optional(b
.getIndexType())
83 Type
getFunctionTypeChain(unsigned N
) {
84 Type type
= b
.getIndexType();
85 for (unsigned i
= 0; i
< N
; i
++)
86 type
= b
.getFunctionType({}, type
);
92 CyclicAttrTypeReplacer replacer
;
96 TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest
, testPruneAnywhere0
) {
97 setBaseCase(std::nullopt
);
100 EXPECT_EQ(replacer
.replace(b
.getIndexType()), b
.getIndexType());
101 EXPECT_EQ(invokeCount
, 0);
103 // Starting at 0. Cycle length is 3.
105 EXPECT_EQ(replacer
.replace(mlir::IntegerType::get(&ctx
, 0)),
106 getFunctionTypeChain(3));
107 EXPECT_EQ(invokeCount
, 3);
109 // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
111 EXPECT_EQ(replacer
.replace(mlir::IntegerType::get(&ctx
, 1)),
112 getFunctionTypeChain(5));
113 EXPECT_EQ(invokeCount
, 2);
116 TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest
, testPruneAnywhere1
) {
117 setBaseCase(std::nullopt
);
119 // Starting at 1. Cycle length is 3.
120 EXPECT_EQ(replacer
.replace(mlir::IntegerType::get(&ctx
, 1)),
121 getFunctionTypeChain(3));
122 EXPECT_EQ(invokeCount
, 3);
125 TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest
, testPruneSpecific0
) {
128 // Starting at 0. Cycle length is 3.
129 EXPECT_EQ(replacer
.replace(mlir::IntegerType::get(&ctx
, 0)),
130 getFunctionTypeChain(3));
131 EXPECT_EQ(invokeCount
, 3);
134 TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest
, testPruneSpecific1
) {
137 // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
138 EXPECT_EQ(replacer
.replace(mlir::IntegerType::get(&ctx
, 1)),
139 getFunctionTypeChain(5));
140 EXPECT_EQ(invokeCount
, 5);
//===----------------------------------------------------------------------===//
// CyclicAttrTypeReplacerTest: BranchingRecursion
//===----------------------------------------------------------------------===//
147 class CyclicAttrTypeReplacerBranchingRecusionPruningTest
148 : public ::testing::Test
{
150 CyclicAttrTypeReplacerBranchingRecusionPruningTest() : b(&ctx
) {
151 // IntegerType<width = N>
153 // IntegerType< width = (N+1) % 3> =>
154 // IntegerType< width = (N+1) % 3>>.
155 // This will create a binary tree of infinite depth without pruning.
156 replacer
.addReplacement([&](mlir::IntegerType intType
) {
158 Type child
= mlir::IntegerType::get(&ctx
, (intType
.getWidth() + 1) % 3);
159 return b
.getFunctionType({child
}, {child
});
163 void setBaseCase(std::optional
<unsigned> pruneAt
) {
164 replacer
.addCycleBreaker([&, pruneAt
](mlir::IntegerType intType
) {
165 return (!pruneAt
|| intType
.getWidth() == *pruneAt
)
166 ? std::make_optional(b
.getIndexType())
171 Type
getFunctionTypeTree(unsigned N
) {
172 Type type
= b
.getIndexType();
173 for (unsigned i
= 0; i
< N
; i
++)
174 type
= b
.getFunctionType(type
, type
);
180 CyclicAttrTypeReplacer replacer
;
184 TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest
, testPruneAnywhere0
) {
185 setBaseCase(std::nullopt
);
187 // No recursion case.
188 EXPECT_EQ(replacer
.replace(b
.getIndexType()), b
.getIndexType());
189 EXPECT_EQ(invokeCount
, 0);
191 // Starting at 0. Cycle length is 3.
193 EXPECT_EQ(replacer
.replace(mlir::IntegerType::get(&ctx
, 0)),
194 getFunctionTypeTree(3));
195 // Since both branches are identical, this should incur linear invocations
196 // of the replacement function instead of exponential.
197 EXPECT_EQ(invokeCount
, 3);
199 // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
201 EXPECT_EQ(replacer
.replace(mlir::IntegerType::get(&ctx
, 1)),
202 getFunctionTypeTree(5));
203 EXPECT_EQ(invokeCount
, 2);
206 TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest
, testPruneAnywhere1
) {
207 setBaseCase(std::nullopt
);
209 // Starting at 1. Cycle length is 3.
210 EXPECT_EQ(replacer
.replace(mlir::IntegerType::get(&ctx
, 1)),
211 getFunctionTypeTree(3));
212 EXPECT_EQ(invokeCount
, 3);
215 TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest
, testPruneSpecific0
) {
218 // Starting at 0. Cycle length is 3.
219 EXPECT_EQ(replacer
.replace(mlir::IntegerType::get(&ctx
, 0)),
220 getFunctionTypeTree(3));
221 EXPECT_EQ(invokeCount
, 3);
224 TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest
, testPruneSpecific1
) {
227 // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
228 EXPECT_EQ(replacer
.replace(mlir::IntegerType::get(&ctx
, 1)),
229 getFunctionTypeTree(5));
230 EXPECT_EQ(invokeCount
, 5);