//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallPtrSet.h"

#include <utility>

using namespace mlir;

//===----------------------------------------------------------------------===//
// ControlFlowInterfaces
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
23 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands
)
24 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands
)) {
27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount
,
28 MutableOperandRange forwardedOperands
)
29 : producedOperandCount(producedOperandCount
),
30 forwardedOperands(std::move(forwardedOperands
)) {}
//===----------------------------------------------------------------------===//
// BranchOpInterface
//===----------------------------------------------------------------------===//
36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37 /// successor if 'operandIndex' is within the range of 'operands', or
38 /// std::nullopt if `operandIndex` isn't a successor operand index.
39 std::optional
<BlockArgument
>
40 detail::getBranchSuccessorArgument(const SuccessorOperands
&operands
,
41 unsigned operandIndex
, Block
*successor
) {
42 OperandRange forwardedOperands
= operands
.getForwardedOperands();
43 // Check that the operands are valid.
44 if (forwardedOperands
.empty())
47 // Check to ensure that this operand is within the range.
48 unsigned operandsStart
= forwardedOperands
.getBeginOperandIndex();
49 if (operandIndex
< operandsStart
||
50 operandIndex
>= (operandsStart
+ forwardedOperands
.size()))
53 // Index the successor.
55 operands
.getProducedOperandCount() + operandIndex
- operandsStart
;
56 return successor
->getArgument(argIndex
);
59 /// Verify that the given operands match those of the given successor block.
61 detail::verifyBranchSuccessorOperands(Operation
*op
, unsigned succNo
,
62 const SuccessorOperands
&operands
) {
64 unsigned operandCount
= operands
.size();
65 Block
*destBB
= op
->getSuccessor(succNo
);
66 if (operandCount
!= destBB
->getNumArguments())
67 return op
->emitError() << "branch has " << operandCount
68 << " operands for successor #" << succNo
69 << ", but target block has "
70 << destBB
->getNumArguments();
73 for (unsigned i
= operands
.getProducedOperandCount(); i
!= operandCount
;
75 if (!cast
<BranchOpInterface
>(op
).areTypesCompatible(
76 operands
[i
].getType(), destBB
->getArgument(i
).getType()))
77 return op
->emitError() << "type mismatch for bb argument #" << i
78 << " of successor #" << succNo
;
//===----------------------------------------------------------------------===//
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
87 /// Verify that types match along all region control flow edges originating from
88 /// `sourceNo` (region # if source is a region, std::nullopt if source is parent
89 /// op). `getInputsTypesForRegion` is a function that returns the types of the
90 /// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
91 /// the exact type match verification is not necessary (e.g., if the Op verifies
92 /// the match itself).
93 static LogicalResult
verifyTypesAlongAllEdges(
94 Operation
*op
, std::optional
<unsigned> sourceNo
,
95 function_ref
<std::optional
<TypeRange
>(std::optional
<unsigned>)>
96 getInputsTypesForRegion
) {
97 auto regionInterface
= cast
<RegionBranchOpInterface
>(op
);
99 SmallVector
<RegionSuccessor
, 2> successors
;
100 regionInterface
.getSuccessorRegions(sourceNo
, successors
);
102 for (RegionSuccessor
&succ
: successors
) {
103 std::optional
<unsigned> succRegionNo
;
104 if (!succ
.isParent())
105 succRegionNo
= succ
.getSuccessor()->getRegionNumber();
107 auto printEdgeName
= [&](InFlightDiagnostic
&diag
) -> InFlightDiagnostic
& {
110 diag
<< "Region #" << sourceNo
.value();
112 diag
<< "parent operands";
116 diag
<< "Region #" << succRegionNo
.value();
118 diag
<< "parent results";
122 std::optional
<TypeRange
> sourceTypes
=
123 getInputsTypesForRegion(succRegionNo
);
124 if (!sourceTypes
.has_value())
127 TypeRange succInputsTypes
= succ
.getSuccessorInputs().getTypes();
128 if (sourceTypes
->size() != succInputsTypes
.size()) {
129 InFlightDiagnostic diag
= op
->emitOpError(" region control flow edge ");
130 return printEdgeName(diag
) << ": source has " << sourceTypes
->size()
131 << " operands, but target successor needs "
132 << succInputsTypes
.size();
135 for (const auto &typesIdx
:
136 llvm::enumerate(llvm::zip(*sourceTypes
, succInputsTypes
))) {
137 Type sourceType
= std::get
<0>(typesIdx
.value());
138 Type inputType
= std::get
<1>(typesIdx
.value());
139 if (!regionInterface
.areTypesCompatible(sourceType
, inputType
)) {
140 InFlightDiagnostic diag
= op
->emitOpError(" along control flow edge ");
141 return printEdgeName(diag
)
142 << ": source type #" << typesIdx
.index() << " " << sourceType
143 << " should match input type #" << typesIdx
.index() << " "
151 /// Verify that types match along control flow edges described the given op.
152 LogicalResult
detail::verifyTypesAlongControlFlowEdges(Operation
*op
) {
153 auto regionInterface
= cast
<RegionBranchOpInterface
>(op
);
155 auto inputTypesFromParent
=
156 [&](std::optional
<unsigned> regionNo
) -> TypeRange
{
157 return regionInterface
.getSuccessorEntryOperands(regionNo
).getTypes();
160 // Verify types along control flow edges originating from the parent.
161 if (failed(verifyTypesAlongAllEdges(op
, std::nullopt
, inputTypesFromParent
)))
164 auto areTypesCompatible
= [&](TypeRange lhs
, TypeRange rhs
) {
165 if (lhs
.size() != rhs
.size())
167 for (auto types
: llvm::zip(lhs
, rhs
)) {
168 if (!regionInterface
.areTypesCompatible(std::get
<0>(types
),
169 std::get
<1>(types
))) {
176 // Verify types along control flow edges originating from each region.
177 for (unsigned regionNo
: llvm::seq(0U, op
->getNumRegions())) {
178 Region
®ion
= op
->getRegion(regionNo
);
180 // Since there can be multiple `ReturnLike` terminators or others
181 // implementing the `RegionBranchTerminatorOpInterface`, all should have the
182 // same operand types when passing them to the same region.
184 std::optional
<OperandRange
> regionReturnOperands
;
185 for (Block
&block
: region
) {
186 Operation
*terminator
= block
.getTerminator();
187 auto terminatorOperands
=
188 getRegionBranchSuccessorOperands(terminator
, regionNo
);
189 if (!terminatorOperands
)
192 if (!regionReturnOperands
) {
193 regionReturnOperands
= terminatorOperands
;
197 // Found more than one ReturnLike terminator. Make sure the operand types
198 // match with the first one.
199 if (!areTypesCompatible(regionReturnOperands
->getTypes(),
200 terminatorOperands
->getTypes()))
201 return op
->emitOpError("Region #")
203 << " operands mismatch between return-like terminators";
206 auto inputTypesFromRegion
=
207 [&](std::optional
<unsigned> regionNo
) -> std::optional
<TypeRange
> {
208 // If there is no return-like terminator, the op itself should verify
210 if (!regionReturnOperands
)
213 // All successors get the same set of operand types.
214 return TypeRange(regionReturnOperands
->getTypes());
217 if (failed(verifyTypesAlongAllEdges(op
, regionNo
, inputTypesFromRegion
)))
224 /// Return `true` if region `r` is reachable from region `begin` according to
225 /// the RegionBranchOpInterface (by taking a branch).
226 static bool isRegionReachable(Region
*begin
, Region
*r
) {
227 assert(begin
->getParentOp() == r
->getParentOp() &&
228 "expected that both regions belong to the same op");
229 auto op
= cast
<RegionBranchOpInterface
>(begin
->getParentOp());
230 SmallVector
<bool> visited(op
->getNumRegions(), false);
231 visited
[begin
->getRegionNumber()] = true;
233 // Retrieve all successors of the region and enqueue them in the worklist.
234 SmallVector
<unsigned> worklist
;
235 auto enqueueAllSuccessors
= [&](unsigned index
) {
236 SmallVector
<RegionSuccessor
> successors
;
237 op
.getSuccessorRegions(index
, successors
);
238 for (RegionSuccessor successor
: successors
)
239 if (!successor
.isParent())
240 worklist
.push_back(successor
.getSuccessor()->getRegionNumber());
242 enqueueAllSuccessors(begin
->getRegionNumber());
244 // Process all regions in the worklist via DFS.
245 while (!worklist
.empty()) {
246 unsigned nextRegion
= worklist
.pop_back_val();
247 if (nextRegion
== r
->getRegionNumber())
249 if (visited
[nextRegion
])
251 visited
[nextRegion
] = true;
252 enqueueAllSuccessors(nextRegion
);
258 /// Return `true` if `a` and `b` are in mutually exclusive regions.
260 /// 1. Find the first common of `a` and `b` (ancestor) that implements
261 /// RegionBranchOpInterface.
262 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
264 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
265 /// mutually exclusive if they are not reachable from each other as per
266 /// RegionBranchOpInterface::getSuccessorRegions.
267 bool mlir::insideMutuallyExclusiveRegions(Operation
*a
, Operation
*b
) {
268 assert(a
&& "expected non-empty operation");
269 assert(b
&& "expected non-empty operation");
271 auto branchOp
= a
->getParentOfType
<RegionBranchOpInterface
>();
273 // Check if b is inside branchOp. (We already know that a is.)
274 if (!branchOp
->isProperAncestor(b
)) {
275 // Check next enclosing RegionBranchOpInterface.
276 branchOp
= branchOp
->getParentOfType
<RegionBranchOpInterface
>();
280 // b is contained in branchOp. Retrieve the regions in which `a` and `b`
282 Region
*regionA
= nullptr, *regionB
= nullptr;
283 for (Region
&r
: branchOp
->getRegions()) {
284 if (r
.findAncestorOpInRegion(*a
)) {
285 assert(!regionA
&& "already found a region for a");
288 if (r
.findAncestorOpInRegion(*b
)) {
289 assert(!regionB
&& "already found a region for b");
293 assert(regionA
&& regionB
&& "could not find region of op");
295 // `a` and `b` are in mutually exclusive regions if both regions are
296 // distinct and neither region is reachable from the other region.
297 return regionA
!= regionB
&& !isRegionReachable(regionA
, regionB
) &&
298 !isRegionReachable(regionB
, regionA
);
301 // Could not find a common RegionBranchOpInterface among a's and b's
306 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index
) {
307 Region
*region
= &getOperation()->getRegion(index
);
308 return isRegionReachable(region
, region
);
311 void RegionBranchOpInterface::getSuccessorRegions(
312 std::optional
<unsigned> index
, SmallVectorImpl
<RegionSuccessor
> ®ions
) {
313 unsigned numInputs
= 0;
315 // If the predecessor is a region, get the number of operands from an
316 // exiting terminator in the region.
317 for (Block
&block
: getOperation()->getRegion(*index
)) {
318 Operation
*terminator
= block
.getTerminator();
319 if (getRegionBranchSuccessorOperands(terminator
, *index
)) {
320 numInputs
= terminator
->getNumOperands();
325 // Otherwise, use the number of parent operation operands.
326 numInputs
= getOperation()->getNumOperands();
328 SmallVector
<Attribute
, 2> operands(numInputs
, nullptr);
329 getSuccessorRegions(index
, operands
, regions
);
332 Region
*mlir::getEnclosingRepetitiveRegion(Operation
*op
) {
333 while (Region
*region
= op
->getParentRegion()) {
334 op
= region
->getParentOp();
335 if (auto branchOp
= dyn_cast
<RegionBranchOpInterface
>(op
))
336 if (branchOp
.isRepetitiveRegion(region
->getRegionNumber()))
342 Region
*mlir::getEnclosingRepetitiveRegion(Value value
) {
343 Region
*region
= value
.getParentRegion();
345 Operation
*op
= region
->getParentOp();
346 if (auto branchOp
= dyn_cast
<RegionBranchOpInterface
>(op
))
347 if (branchOp
.isRepetitiveRegion(region
->getRegionNumber()))
349 region
= op
->getParentRegion();
//===----------------------------------------------------------------------===//
// RegionBranchTerminatorOpInterface
//===----------------------------------------------------------------------===//
358 /// Returns true if the given operation is either annotated with the
359 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
360 bool mlir::isRegionReturnLike(Operation
*operation
) {
361 return dyn_cast
<RegionBranchTerminatorOpInterface
>(operation
) ||
362 operation
->hasTrait
<OpTrait::ReturnLike
>();
365 /// Returns the mutable operands that are passed to the region with the given
366 /// `regionIndex`. If the operation does not implement the
367 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
368 /// result will be `std::nullopt`. In all other cases, the resulting
369 /// `OperandRange` represents all operands that are passed to the specified
370 /// successor region. If `regionIndex` is `std::nullopt`, all operands that are
371 /// passed to the parent operation will be returned.
372 std::optional
<MutableOperandRange
>
373 mlir::getMutableRegionBranchSuccessorOperands(
374 Operation
*operation
, std::optional
<unsigned> regionIndex
) {
375 // Try to query a RegionBranchTerminatorOpInterface to determine
376 // all successor operands that will be passed to the successor
378 if (auto regionTerminatorInterface
=
379 dyn_cast
<RegionBranchTerminatorOpInterface
>(operation
))
380 return regionTerminatorInterface
.getMutableSuccessorOperands(regionIndex
);
382 // TODO: The ReturnLike trait should imply a default implementation of the
383 // RegionBranchTerminatorOpInterface. This would make this code significantly
384 // easier. Furthermore, this may even make this function obsolete.
385 if (operation
->hasTrait
<OpTrait::ReturnLike
>())
386 return MutableOperandRange(operation
);
390 /// Returns the read only operands that are passed to the region with the given
391 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
393 std::optional
<OperandRange
>
394 mlir::getRegionBranchSuccessorOperands(Operation
*operation
,
395 std::optional
<unsigned> regionIndex
) {
396 auto range
= getMutableRegionBranchSuccessorOperands(operation
, regionIndex
);
398 return range
->operator OperandRange();