//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"

#include <optional>
#include <utility>

using namespace mlir;

//===----------------------------------------------------------------------===//
// ControlFlowInterfaces
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"

23 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands
)
24 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands
)) {
27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount
,
28 MutableOperandRange forwardedOperands
)
29 : producedOperandCount(producedOperandCount
),
30 forwardedOperands(std::move(forwardedOperands
)) {}
//===----------------------------------------------------------------------===//
// BranchOpInterface
//===----------------------------------------------------------------------===//

36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37 /// successor if 'operandIndex' is within the range of 'operands', or
38 /// std::nullopt if `operandIndex` isn't a successor operand index.
39 std::optional
<BlockArgument
>
40 detail::getBranchSuccessorArgument(const SuccessorOperands
&operands
,
41 unsigned operandIndex
, Block
*successor
) {
42 OperandRange forwardedOperands
= operands
.getForwardedOperands();
43 // Check that the operands are valid.
44 if (forwardedOperands
.empty())
47 // Check to ensure that this operand is within the range.
48 unsigned operandsStart
= forwardedOperands
.getBeginOperandIndex();
49 if (operandIndex
< operandsStart
||
50 operandIndex
>= (operandsStart
+ forwardedOperands
.size()))
53 // Index the successor.
55 operands
.getProducedOperandCount() + operandIndex
- operandsStart
;
56 return successor
->getArgument(argIndex
);
59 /// Verify that the given operands match those of the given successor block.
61 detail::verifyBranchSuccessorOperands(Operation
*op
, unsigned succNo
,
62 const SuccessorOperands
&operands
) {
64 unsigned operandCount
= operands
.size();
65 Block
*destBB
= op
->getSuccessor(succNo
);
66 if (operandCount
!= destBB
->getNumArguments())
67 return op
->emitError() << "branch has " << operandCount
68 << " operands for successor #" << succNo
69 << ", but target block has "
70 << destBB
->getNumArguments();
73 for (unsigned i
= operands
.getProducedOperandCount(); i
!= operandCount
;
75 if (!cast
<BranchOpInterface
>(op
).areTypesCompatible(
76 operands
[i
].getType(), destBB
->getArgument(i
).getType()))
77 return op
->emitError() << "type mismatch for bb argument #" << i
78 << " of successor #" << succNo
;
//===----------------------------------------------------------------------===//
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//

87 static InFlightDiagnostic
&printRegionEdgeName(InFlightDiagnostic
&diag
,
88 RegionBranchPoint sourceNo
,
89 RegionBranchPoint succRegionNo
) {
91 if (Region
*region
= sourceNo
.getRegionOrNull())
92 diag
<< "Region #" << region
->getRegionNumber();
94 diag
<< "parent operands";
97 if (Region
*region
= succRegionNo
.getRegionOrNull())
98 diag
<< "Region #" << region
->getRegionNumber();
100 diag
<< "parent results";
104 /// Verify that types match along all region control flow edges originating from
105 /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
106 /// types of the inputs that flow to a successor region.
108 verifyTypesAlongAllEdges(Operation
*op
, RegionBranchPoint sourcePoint
,
109 function_ref
<FailureOr
<TypeRange
>(RegionBranchPoint
)>
110 getInputsTypesForRegion
) {
111 auto regionInterface
= cast
<RegionBranchOpInterface
>(op
);
113 SmallVector
<RegionSuccessor
, 2> successors
;
114 regionInterface
.getSuccessorRegions(sourcePoint
, successors
);
116 for (RegionSuccessor
&succ
: successors
) {
117 FailureOr
<TypeRange
> sourceTypes
= getInputsTypesForRegion(succ
);
118 if (failed(sourceTypes
))
121 TypeRange succInputsTypes
= succ
.getSuccessorInputs().getTypes();
122 if (sourceTypes
->size() != succInputsTypes
.size()) {
123 InFlightDiagnostic diag
= op
->emitOpError(" region control flow edge ");
124 return printRegionEdgeName(diag
, sourcePoint
, succ
)
125 << ": source has " << sourceTypes
->size()
126 << " operands, but target successor needs "
127 << succInputsTypes
.size();
130 for (const auto &typesIdx
:
131 llvm::enumerate(llvm::zip(*sourceTypes
, succInputsTypes
))) {
132 Type sourceType
= std::get
<0>(typesIdx
.value());
133 Type inputType
= std::get
<1>(typesIdx
.value());
134 if (!regionInterface
.areTypesCompatible(sourceType
, inputType
)) {
135 InFlightDiagnostic diag
= op
->emitOpError(" along control flow edge ");
136 return printRegionEdgeName(diag
, sourcePoint
, succ
)
137 << ": source type #" << typesIdx
.index() << " " << sourceType
138 << " should match input type #" << typesIdx
.index() << " "
146 /// Verify that types match along control flow edges described the given op.
147 LogicalResult
detail::verifyTypesAlongControlFlowEdges(Operation
*op
) {
148 auto regionInterface
= cast
<RegionBranchOpInterface
>(op
);
150 auto inputTypesFromParent
= [&](RegionBranchPoint point
) -> TypeRange
{
151 return regionInterface
.getEntrySuccessorOperands(point
).getTypes();
154 // Verify types along control flow edges originating from the parent.
155 if (failed(verifyTypesAlongAllEdges(op
, RegionBranchPoint::parent(),
156 inputTypesFromParent
)))
159 auto areTypesCompatible
= [&](TypeRange lhs
, TypeRange rhs
) {
160 if (lhs
.size() != rhs
.size())
162 for (auto types
: llvm::zip(lhs
, rhs
)) {
163 if (!regionInterface
.areTypesCompatible(std::get
<0>(types
),
164 std::get
<1>(types
))) {
171 // Verify types along control flow edges originating from each region.
172 for (Region
®ion
: op
->getRegions()) {
174 // Since there can be multiple terminators implementing the
175 // `RegionBranchTerminatorOpInterface`, all should have the same operand
176 // types when passing them to the same region.
178 SmallVector
<RegionBranchTerminatorOpInterface
> regionReturnOps
;
179 for (Block
&block
: region
)
180 if (auto terminator
= dyn_cast
<RegionBranchTerminatorOpInterface
>(
181 block
.getTerminator()))
182 regionReturnOps
.push_back(terminator
);
184 // If there is no return-like terminator, the op itself should verify
186 if (regionReturnOps
.empty())
189 auto inputTypesForRegion
=
190 [&](RegionBranchPoint point
) -> FailureOr
<TypeRange
> {
191 std::optional
<OperandRange
> regionReturnOperands
;
192 for (RegionBranchTerminatorOpInterface regionReturnOp
: regionReturnOps
) {
193 auto terminatorOperands
= regionReturnOp
.getSuccessorOperands(point
);
195 if (!regionReturnOperands
) {
196 regionReturnOperands
= terminatorOperands
;
200 // Found more than one ReturnLike terminator. Make sure the operand
201 // types match with the first one.
202 if (!areTypesCompatible(regionReturnOperands
->getTypes(),
203 terminatorOperands
.getTypes())) {
204 InFlightDiagnostic diag
= op
->emitOpError(" along control flow edge");
205 return printRegionEdgeName(diag
, region
, point
)
206 << " operands mismatch between return-like terminators";
210 // All successors get the same set of operand types.
211 return TypeRange(regionReturnOperands
->getTypes());
214 if (failed(verifyTypesAlongAllEdges(op
, region
, inputTypesForRegion
)))
221 /// Return `true` if region `r` is reachable from region `begin` according to
222 /// the RegionBranchOpInterface (by taking a branch).
223 static bool isRegionReachable(Region
*begin
, Region
*r
) {
224 assert(begin
->getParentOp() == r
->getParentOp() &&
225 "expected that both regions belong to the same op");
226 auto op
= cast
<RegionBranchOpInterface
>(begin
->getParentOp());
227 SmallVector
<bool> visited(op
->getNumRegions(), false);
228 visited
[begin
->getRegionNumber()] = true;
230 // Retrieve all successors of the region and enqueue them in the worklist.
231 SmallVector
<Region
*> worklist
;
232 auto enqueueAllSuccessors
= [&](Region
*region
) {
233 SmallVector
<RegionSuccessor
> successors
;
234 op
.getSuccessorRegions(region
, successors
);
235 for (RegionSuccessor successor
: successors
)
236 if (!successor
.isParent())
237 worklist
.push_back(successor
.getSuccessor());
239 enqueueAllSuccessors(begin
);
241 // Process all regions in the worklist via DFS.
242 while (!worklist
.empty()) {
243 Region
*nextRegion
= worklist
.pop_back_val();
246 if (visited
[nextRegion
->getRegionNumber()])
248 visited
[nextRegion
->getRegionNumber()] = true;
249 enqueueAllSuccessors(nextRegion
);
255 /// Return `true` if `a` and `b` are in mutually exclusive regions.
257 /// 1. Find the first common of `a` and `b` (ancestor) that implements
258 /// RegionBranchOpInterface.
259 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
261 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
262 /// mutually exclusive if they are not reachable from each other as per
263 /// RegionBranchOpInterface::getSuccessorRegions.
264 bool mlir::insideMutuallyExclusiveRegions(Operation
*a
, Operation
*b
) {
265 assert(a
&& "expected non-empty operation");
266 assert(b
&& "expected non-empty operation");
268 auto branchOp
= a
->getParentOfType
<RegionBranchOpInterface
>();
270 // Check if b is inside branchOp. (We already know that a is.)
271 if (!branchOp
->isProperAncestor(b
)) {
272 // Check next enclosing RegionBranchOpInterface.
273 branchOp
= branchOp
->getParentOfType
<RegionBranchOpInterface
>();
277 // b is contained in branchOp. Retrieve the regions in which `a` and `b`
279 Region
*regionA
= nullptr, *regionB
= nullptr;
280 for (Region
&r
: branchOp
->getRegions()) {
281 if (r
.findAncestorOpInRegion(*a
)) {
282 assert(!regionA
&& "already found a region for a");
285 if (r
.findAncestorOpInRegion(*b
)) {
286 assert(!regionB
&& "already found a region for b");
290 assert(regionA
&& regionB
&& "could not find region of op");
292 // `a` and `b` are in mutually exclusive regions if both regions are
293 // distinct and neither region is reachable from the other region.
294 return regionA
!= regionB
&& !isRegionReachable(regionA
, regionB
) &&
295 !isRegionReachable(regionB
, regionA
);
298 // Could not find a common RegionBranchOpInterface among a's and b's
303 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index
) {
304 Region
*region
= &getOperation()->getRegion(index
);
305 return isRegionReachable(region
, region
);
308 Region
*mlir::getEnclosingRepetitiveRegion(Operation
*op
) {
309 while (Region
*region
= op
->getParentRegion()) {
310 op
= region
->getParentOp();
311 if (auto branchOp
= dyn_cast
<RegionBranchOpInterface
>(op
))
312 if (branchOp
.isRepetitiveRegion(region
->getRegionNumber()))
318 Region
*mlir::getEnclosingRepetitiveRegion(Value value
) {
319 Region
*region
= value
.getParentRegion();
321 Operation
*op
= region
->getParentOp();
322 if (auto branchOp
= dyn_cast
<RegionBranchOpInterface
>(op
))
323 if (branchOp
.isRepetitiveRegion(region
->getRegionNumber()))
325 region
= op
->getParentRegion();