//===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Parser/Parser.h"

#include <gtest/gtest.h>

using namespace mlir;
21 /// A dummy op that is also a terminator.
22 struct DummyOp
: public Op
<DummyOp
, OpTrait::IsTerminator
> {
24 static ArrayRef
<StringRef
> getAttributeNames() { return {}; }
26 static StringRef
getOperationName() { return "cftest.dummy_op"; }
29 /// All regions of this op are mutually exclusive.
30 struct MutuallyExclusiveRegionsOp
31 : public Op
<MutuallyExclusiveRegionsOp
, RegionBranchOpInterface::Trait
> {
33 static ArrayRef
<StringRef
> getAttributeNames() { return {}; }
35 static StringRef
getOperationName() {
36 return "cftest.mutually_exclusive_regions_op";
39 // Regions have no successors.
40 void getSuccessorRegions(RegionBranchPoint point
,
41 SmallVectorImpl
<RegionSuccessor
> ®ions
) {}
44 /// All regions of this op call each other in a large circle.
46 : public Op
<LoopRegionsOp
, RegionBranchOpInterface::Trait
> {
48 static const unsigned kNumRegions
= 3;
50 static ArrayRef
<StringRef
> getAttributeNames() { return {}; }
52 static StringRef
getOperationName() { return "cftest.loop_regions_op"; }
54 void getSuccessorRegions(RegionBranchPoint point
,
55 SmallVectorImpl
<RegionSuccessor
> ®ions
) {
56 if (Region
*region
= point
.getRegionOrNull()) {
57 if (point
== (*this)->getRegion(1))
58 // This region also branches back to the parent.
59 regions
.push_back(RegionSuccessor());
60 regions
.push_back(RegionSuccessor(region
));
65 /// Each region branches back it itself or the parent.
66 struct DoubleLoopRegionsOp
67 : public Op
<DoubleLoopRegionsOp
, RegionBranchOpInterface::Trait
> {
70 static ArrayRef
<StringRef
> getAttributeNames() { return {}; }
72 static StringRef
getOperationName() {
73 return "cftest.double_loop_regions_op";
76 void getSuccessorRegions(RegionBranchPoint point
,
77 SmallVectorImpl
<RegionSuccessor
> ®ions
) {
78 if (Region
*region
= point
.getRegionOrNull()) {
79 regions
.push_back(RegionSuccessor());
80 regions
.push_back(RegionSuccessor(region
));
85 /// Regions are executed sequentially.
86 struct SequentialRegionsOp
87 : public Op
<SequentialRegionsOp
, RegionBranchOpInterface::Trait
> {
89 static ArrayRef
<StringRef
> getAttributeNames() { return {}; }
91 static StringRef
getOperationName() { return "cftest.sequential_regions_op"; }
93 // Region 0 has Region 1 as a successor.
94 void getSuccessorRegions(RegionBranchPoint point
,
95 SmallVectorImpl
<RegionSuccessor
> ®ions
) {
96 if (point
== (*this)->getRegion(0)) {
97 Operation
*thisOp
= this->getOperation();
98 regions
.push_back(RegionSuccessor(&thisOp
->getRegion(1)));
103 /// A dialect putting all the above together.
104 struct CFTestDialect
: Dialect
{
105 explicit CFTestDialect(MLIRContext
*ctx
)
106 : Dialect(getDialectNamespace(), ctx
, TypeID::get
<CFTestDialect
>()) {
107 addOperations
<DummyOp
, MutuallyExclusiveRegionsOp
, LoopRegionsOp
,
108 DoubleLoopRegionsOp
, SequentialRegionsOp
>();
110 static StringRef
getDialectNamespace() { return "cftest"; }
113 TEST(RegionBranchOpInterface
, MutuallyExclusiveOps
) {
114 const char *ir
= R
"MLIR(
115 "cftest
.mutually_exclusive_regions_op
"() (
116 {"cftest
.dummy_op
"() : () -> ()}, // op1
117 {"cftest
.dummy_op
"() : () -> ()} // op2
121 DialectRegistry registry
;
122 registry
.insert
<CFTestDialect
>();
123 MLIRContext
ctx(registry
);
125 OwningOpRef
<ModuleOp
> module
= parseSourceString
<ModuleOp
>(ir
, &ctx
);
126 Operation
*testOp
= &module
->getBody()->getOperations().front();
127 Operation
*op1
= &testOp
->getRegion(0).front().front();
128 Operation
*op2
= &testOp
->getRegion(1).front().front();
130 EXPECT_TRUE(insideMutuallyExclusiveRegions(op1
, op2
));
131 EXPECT_TRUE(insideMutuallyExclusiveRegions(op2
, op1
));
134 TEST(RegionBranchOpInterface
, MutuallyExclusiveOps2
) {
135 const char *ir
= R
"MLIR(
136 "cftest
.double_loop_regions_op
"() (
137 {"cftest
.dummy_op
"() : () -> ()}, // op1
138 {"cftest
.dummy_op
"() : () -> ()} // op2
142 DialectRegistry registry
;
143 registry
.insert
<CFTestDialect
>();
144 MLIRContext
ctx(registry
);
146 OwningOpRef
<ModuleOp
> module
= parseSourceString
<ModuleOp
>(ir
, &ctx
);
147 Operation
*testOp
= &module
->getBody()->getOperations().front();
148 Operation
*op1
= &testOp
->getRegion(0).front().front();
149 Operation
*op2
= &testOp
->getRegion(1).front().front();
151 EXPECT_TRUE(insideMutuallyExclusiveRegions(op1
, op2
));
152 EXPECT_TRUE(insideMutuallyExclusiveRegions(op2
, op1
));
155 TEST(RegionBranchOpInterface
, NotMutuallyExclusiveOps
) {
156 const char *ir
= R
"MLIR(
157 "cftest
.sequential_regions_op
"() (
158 {"cftest
.dummy_op
"() : () -> ()}, // op1
159 {"cftest
.dummy_op
"() : () -> ()} // op2
163 DialectRegistry registry
;
164 registry
.insert
<CFTestDialect
>();
165 MLIRContext
ctx(registry
);
167 OwningOpRef
<ModuleOp
> module
= parseSourceString
<ModuleOp
>(ir
, &ctx
);
168 Operation
*testOp
= &module
->getBody()->getOperations().front();
169 Operation
*op1
= &testOp
->getRegion(0).front().front();
170 Operation
*op2
= &testOp
->getRegion(1).front().front();
172 EXPECT_FALSE(insideMutuallyExclusiveRegions(op1
, op2
));
173 EXPECT_FALSE(insideMutuallyExclusiveRegions(op2
, op1
));
176 TEST(RegionBranchOpInterface
, NestedMutuallyExclusiveOps
) {
177 const char *ir
= R
"MLIR(
178 "cftest
.mutually_exclusive_regions_op
"() (
180 "cftest
.sequential_regions_op
"() (
181 {"cftest
.dummy_op
"() : () -> ()}, // op1
182 {"cftest
.dummy_op
"() : () -> ()} // op3
184 "cftest
.dummy_op
"() : () -> ()
186 {"cftest
.dummy_op
"() : () -> ()} // op2
190 DialectRegistry registry
;
191 registry
.insert
<CFTestDialect
>();
192 MLIRContext
ctx(registry
);
194 OwningOpRef
<ModuleOp
> module
= parseSourceString
<ModuleOp
>(ir
, &ctx
);
195 Operation
*testOp
= &module
->getBody()->getOperations().front();
197 &testOp
->getRegion(0).front().front().getRegion(0).front().front();
198 Operation
*op2
= &testOp
->getRegion(1).front().front();
200 &testOp
->getRegion(0).front().front().getRegion(1).front().front();
202 EXPECT_TRUE(insideMutuallyExclusiveRegions(op1
, op2
));
203 EXPECT_TRUE(insideMutuallyExclusiveRegions(op3
, op2
));
204 EXPECT_FALSE(insideMutuallyExclusiveRegions(op1
, op3
));
207 TEST(RegionBranchOpInterface
, RecursiveRegions
) {
208 const char *ir
= R
"MLIR(
209 "cftest
.loop_regions_op
"() (
210 {"cftest
.dummy_op
"() : () -> ()}, // op1
211 {"cftest
.dummy_op
"() : () -> ()}, // op2
212 {"cftest
.dummy_op
"() : () -> ()} // op3
216 DialectRegistry registry
;
217 registry
.insert
<CFTestDialect
>();
218 MLIRContext
ctx(registry
);
220 OwningOpRef
<ModuleOp
> module
= parseSourceString
<ModuleOp
>(ir
, &ctx
);
221 Operation
*testOp
= &module
->getBody()->getOperations().front();
222 auto regionOp
= cast
<RegionBranchOpInterface
>(testOp
);
223 Operation
*op1
= &testOp
->getRegion(0).front().front();
224 Operation
*op2
= &testOp
->getRegion(1).front().front();
225 Operation
*op3
= &testOp
->getRegion(2).front().front();
227 EXPECT_TRUE(regionOp
.isRepetitiveRegion(0));
228 EXPECT_TRUE(regionOp
.isRepetitiveRegion(1));
229 EXPECT_TRUE(regionOp
.isRepetitiveRegion(2));
230 EXPECT_NE(getEnclosingRepetitiveRegion(op1
), nullptr);
231 EXPECT_NE(getEnclosingRepetitiveRegion(op2
), nullptr);
232 EXPECT_NE(getEnclosingRepetitiveRegion(op3
), nullptr);
235 TEST(RegionBranchOpInterface
, NotRecursiveRegions
) {
236 const char *ir
= R
"MLIR(
237 "cftest
.sequential_regions_op
"() (
238 {"cftest
.dummy_op
"() : () -> ()}, // op1
239 {"cftest
.dummy_op
"() : () -> ()} // op2
243 DialectRegistry registry
;
244 registry
.insert
<CFTestDialect
>();
245 MLIRContext
ctx(registry
);
247 OwningOpRef
<ModuleOp
> module
= parseSourceString
<ModuleOp
>(ir
, &ctx
);
248 Operation
*testOp
= &module
->getBody()->getOperations().front();
249 Operation
*op1
= &testOp
->getRegion(0).front().front();
250 Operation
*op2
= &testOp
->getRegion(1).front().front();
252 EXPECT_EQ(getEnclosingRepetitiveRegion(op1
), nullptr);
253 EXPECT_EQ(getEnclosingRepetitiveRegion(op2
), nullptr);