//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"

#include <optional>
#include <utility>

using namespace mlir;

//===----------------------------------------------------------------------===//
// ControlFlowInterfaces
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
23 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands
)
24 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands
)) {
27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount
,
28 MutableOperandRange forwardedOperands
)
29 : producedOperandCount(producedOperandCount
),
30 forwardedOperands(std::move(forwardedOperands
)) {}
//===----------------------------------------------------------------------===//
// BranchOpInterface
//===----------------------------------------------------------------------===//
36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37 /// successor if 'operandIndex' is within the range of 'operands', or
38 /// std::nullopt if `operandIndex` isn't a successor operand index.
39 std::optional
<BlockArgument
>
40 detail::getBranchSuccessorArgument(const SuccessorOperands
&operands
,
41 unsigned operandIndex
, Block
*successor
) {
42 OperandRange forwardedOperands
= operands
.getForwardedOperands();
43 // Check that the operands are valid.
44 if (forwardedOperands
.empty())
47 // Check to ensure that this operand is within the range.
48 unsigned operandsStart
= forwardedOperands
.getBeginOperandIndex();
49 if (operandIndex
< operandsStart
||
50 operandIndex
>= (operandsStart
+ forwardedOperands
.size()))
53 // Index the successor.
55 operands
.getProducedOperandCount() + operandIndex
- operandsStart
;
56 return successor
->getArgument(argIndex
);
59 /// Verify that the given operands match those of the given successor block.
61 detail::verifyBranchSuccessorOperands(Operation
*op
, unsigned succNo
,
62 const SuccessorOperands
&operands
) {
64 unsigned operandCount
= operands
.size();
65 Block
*destBB
= op
->getSuccessor(succNo
);
66 if (operandCount
!= destBB
->getNumArguments())
67 return op
->emitError() << "branch has " << operandCount
68 << " operands for successor #" << succNo
69 << ", but target block has "
70 << destBB
->getNumArguments();
73 for (unsigned i
= operands
.getProducedOperandCount(); i
!= operandCount
;
75 if (!cast
<BranchOpInterface
>(op
).areTypesCompatible(
76 operands
[i
].getType(), destBB
->getArgument(i
).getType()))
77 return op
->emitError() << "type mismatch for bb argument #" << i
78 << " of successor #" << succNo
;
//===----------------------------------------------------------------------===//
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
87 static InFlightDiagnostic
&printRegionEdgeName(InFlightDiagnostic
&diag
,
88 RegionBranchPoint sourceNo
,
89 RegionBranchPoint succRegionNo
) {
91 if (Region
*region
= sourceNo
.getRegionOrNull())
92 diag
<< "Region #" << region
->getRegionNumber();
94 diag
<< "parent operands";
97 if (Region
*region
= succRegionNo
.getRegionOrNull())
98 diag
<< "Region #" << region
->getRegionNumber();
100 diag
<< "parent results";
104 /// Verify that types match along all region control flow edges originating from
105 /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
106 /// types of the inputs that flow to a successor region.
108 verifyTypesAlongAllEdges(Operation
*op
, RegionBranchPoint sourcePoint
,
109 function_ref
<FailureOr
<TypeRange
>(RegionBranchPoint
)>
110 getInputsTypesForRegion
) {
111 auto regionInterface
= cast
<RegionBranchOpInterface
>(op
);
113 SmallVector
<RegionSuccessor
, 2> successors
;
114 regionInterface
.getSuccessorRegions(sourcePoint
, successors
);
116 for (RegionSuccessor
&succ
: successors
) {
117 FailureOr
<TypeRange
> sourceTypes
= getInputsTypesForRegion(succ
);
118 if (failed(sourceTypes
))
121 TypeRange succInputsTypes
= succ
.getSuccessorInputs().getTypes();
122 if (sourceTypes
->size() != succInputsTypes
.size()) {
123 InFlightDiagnostic diag
= op
->emitOpError(" region control flow edge ");
124 return printRegionEdgeName(diag
, sourcePoint
, succ
)
125 << ": source has " << sourceTypes
->size()
126 << " operands, but target successor needs "
127 << succInputsTypes
.size();
130 for (const auto &typesIdx
:
131 llvm::enumerate(llvm::zip(*sourceTypes
, succInputsTypes
))) {
132 Type sourceType
= std::get
<0>(typesIdx
.value());
133 Type inputType
= std::get
<1>(typesIdx
.value());
134 if (!regionInterface
.areTypesCompatible(sourceType
, inputType
)) {
135 InFlightDiagnostic diag
= op
->emitOpError(" along control flow edge ");
136 return printRegionEdgeName(diag
, sourcePoint
, succ
)
137 << ": source type #" << typesIdx
.index() << " " << sourceType
138 << " should match input type #" << typesIdx
.index() << " "
146 /// Verify that types match along control flow edges described the given op.
147 LogicalResult
detail::verifyTypesAlongControlFlowEdges(Operation
*op
) {
148 auto regionInterface
= cast
<RegionBranchOpInterface
>(op
);
150 auto inputTypesFromParent
= [&](RegionBranchPoint point
) -> TypeRange
{
151 return regionInterface
.getEntrySuccessorOperands(point
).getTypes();
154 // Verify types along control flow edges originating from the parent.
155 if (failed(verifyTypesAlongAllEdges(op
, RegionBranchPoint::parent(),
156 inputTypesFromParent
)))
159 auto areTypesCompatible
= [&](TypeRange lhs
, TypeRange rhs
) {
160 if (lhs
.size() != rhs
.size())
162 for (auto types
: llvm::zip(lhs
, rhs
)) {
163 if (!regionInterface
.areTypesCompatible(std::get
<0>(types
),
164 std::get
<1>(types
))) {
171 // Verify types along control flow edges originating from each region.
172 for (Region
®ion
: op
->getRegions()) {
174 // Since there can be multiple terminators implementing the
175 // `RegionBranchTerminatorOpInterface`, all should have the same operand
176 // types when passing them to the same region.
178 SmallVector
<RegionBranchTerminatorOpInterface
> regionReturnOps
;
179 for (Block
&block
: region
)
181 if (auto terminator
=
182 dyn_cast
<RegionBranchTerminatorOpInterface
>(block
.back()))
183 regionReturnOps
.push_back(terminator
);
185 // If there is no return-like terminator, the op itself should verify
187 if (regionReturnOps
.empty())
190 auto inputTypesForRegion
=
191 [&](RegionBranchPoint point
) -> FailureOr
<TypeRange
> {
192 std::optional
<OperandRange
> regionReturnOperands
;
193 for (RegionBranchTerminatorOpInterface regionReturnOp
: regionReturnOps
) {
194 auto terminatorOperands
= regionReturnOp
.getSuccessorOperands(point
);
196 if (!regionReturnOperands
) {
197 regionReturnOperands
= terminatorOperands
;
201 // Found more than one ReturnLike terminator. Make sure the operand
202 // types match with the first one.
203 if (!areTypesCompatible(regionReturnOperands
->getTypes(),
204 terminatorOperands
.getTypes())) {
205 InFlightDiagnostic diag
= op
->emitOpError(" along control flow edge");
206 return printRegionEdgeName(diag
, region
, point
)
207 << " operands mismatch between return-like terminators";
211 // All successors get the same set of operand types.
212 return TypeRange(regionReturnOperands
->getTypes());
215 if (failed(verifyTypesAlongAllEdges(op
, region
, inputTypesForRegion
)))
222 /// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
223 /// this function returns "true" for a successor region. The first parameter is
224 /// the successor region. The second parameter indicates all already visited
226 using StopConditionFn
= function_ref
<bool(Region
*, ArrayRef
<bool> visited
)>;
228 /// Traverse the region graph starting at `begin`. The traversal is interrupted
229 /// if `stopCondition` evaluates to "true" for a successor region. In that case,
230 /// this function returns "true". Otherwise, if the traversal was not
231 /// interrupted, this function returns "false".
232 static bool traverseRegionGraph(Region
*begin
,
233 StopConditionFn stopConditionFn
) {
234 auto op
= cast
<RegionBranchOpInterface
>(begin
->getParentOp());
235 SmallVector
<bool> visited(op
->getNumRegions(), false);
236 visited
[begin
->getRegionNumber()] = true;
238 // Retrieve all successors of the region and enqueue them in the worklist.
239 SmallVector
<Region
*> worklist
;
240 auto enqueueAllSuccessors
= [&](Region
*region
) {
241 SmallVector
<RegionSuccessor
> successors
;
242 op
.getSuccessorRegions(region
, successors
);
243 for (RegionSuccessor successor
: successors
)
244 if (!successor
.isParent())
245 worklist
.push_back(successor
.getSuccessor());
247 enqueueAllSuccessors(begin
);
249 // Process all regions in the worklist via DFS.
250 while (!worklist
.empty()) {
251 Region
*nextRegion
= worklist
.pop_back_val();
252 if (stopConditionFn(nextRegion
, visited
))
254 if (visited
[nextRegion
->getRegionNumber()])
256 visited
[nextRegion
->getRegionNumber()] = true;
257 enqueueAllSuccessors(nextRegion
);
263 /// Return `true` if region `r` is reachable from region `begin` according to
264 /// the RegionBranchOpInterface (by taking a branch).
265 static bool isRegionReachable(Region
*begin
, Region
*r
) {
266 assert(begin
->getParentOp() == r
->getParentOp() &&
267 "expected that both regions belong to the same op");
268 return traverseRegionGraph(begin
,
269 [&](Region
*nextRegion
, ArrayRef
<bool> visited
) {
270 // Interrupt traversal if `r` was reached.
271 return nextRegion
== r
;
275 /// Return `true` if `a` and `b` are in mutually exclusive regions.
277 /// 1. Find the first common of `a` and `b` (ancestor) that implements
278 /// RegionBranchOpInterface.
279 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
281 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
282 /// mutually exclusive if they are not reachable from each other as per
283 /// RegionBranchOpInterface::getSuccessorRegions.
284 bool mlir::insideMutuallyExclusiveRegions(Operation
*a
, Operation
*b
) {
285 assert(a
&& "expected non-empty operation");
286 assert(b
&& "expected non-empty operation");
288 auto branchOp
= a
->getParentOfType
<RegionBranchOpInterface
>();
290 // Check if b is inside branchOp. (We already know that a is.)
291 if (!branchOp
->isProperAncestor(b
)) {
292 // Check next enclosing RegionBranchOpInterface.
293 branchOp
= branchOp
->getParentOfType
<RegionBranchOpInterface
>();
297 // b is contained in branchOp. Retrieve the regions in which `a` and `b`
299 Region
*regionA
= nullptr, *regionB
= nullptr;
300 for (Region
&r
: branchOp
->getRegions()) {
301 if (r
.findAncestorOpInRegion(*a
)) {
302 assert(!regionA
&& "already found a region for a");
305 if (r
.findAncestorOpInRegion(*b
)) {
306 assert(!regionB
&& "already found a region for b");
310 assert(regionA
&& regionB
&& "could not find region of op");
312 // `a` and `b` are in mutually exclusive regions if both regions are
313 // distinct and neither region is reachable from the other region.
314 return regionA
!= regionB
&& !isRegionReachable(regionA
, regionB
) &&
315 !isRegionReachable(regionB
, regionA
);
318 // Could not find a common RegionBranchOpInterface among a's and b's
323 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index
) {
324 Region
*region
= &getOperation()->getRegion(index
);
325 return isRegionReachable(region
, region
);
328 bool RegionBranchOpInterface::hasLoop() {
329 SmallVector
<RegionSuccessor
> entryRegions
;
330 getSuccessorRegions(RegionBranchPoint::parent(), entryRegions
);
331 for (RegionSuccessor successor
: entryRegions
)
332 if (!successor
.isParent() &&
333 traverseRegionGraph(successor
.getSuccessor(),
334 [](Region
*nextRegion
, ArrayRef
<bool> visited
) {
335 // Interrupt traversal if the region was already
337 return visited
[nextRegion
->getRegionNumber()];
343 Region
*mlir::getEnclosingRepetitiveRegion(Operation
*op
) {
344 while (Region
*region
= op
->getParentRegion()) {
345 op
= region
->getParentOp();
346 if (auto branchOp
= dyn_cast
<RegionBranchOpInterface
>(op
))
347 if (branchOp
.isRepetitiveRegion(region
->getRegionNumber()))
353 Region
*mlir::getEnclosingRepetitiveRegion(Value value
) {
354 Region
*region
= value
.getParentRegion();
356 Operation
*op
= region
->getParentOp();
357 if (auto branchOp
= dyn_cast
<RegionBranchOpInterface
>(op
))
358 if (branchOp
.isRepetitiveRegion(region
->getRegionNumber()))
360 region
= op
->getParentRegion();