[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / lib / Interfaces / ControlFlowInterfaces.cpp
blob4ed024ddae247b03ffbe0daf1d95995a2da3dca2
1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include <utility>
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 #include "llvm/ADT/SmallPtrSet.h"
15 using namespace mlir;
17 //===----------------------------------------------------------------------===//
18 // ControlFlowInterfaces
19 //===----------------------------------------------------------------------===//
21 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
23 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
24 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
28 MutableOperandRange forwardedOperands)
29 : producedOperandCount(producedOperandCount),
30 forwardedOperands(std::move(forwardedOperands)) {}
32 //===----------------------------------------------------------------------===//
33 // BranchOpInterface
34 //===----------------------------------------------------------------------===//
36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37 /// successor if 'operandIndex' is within the range of 'operands', or
38 /// std::nullopt if `operandIndex` isn't a successor operand index.
39 std::optional<BlockArgument>
40 detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
41 unsigned operandIndex, Block *successor) {
42 OperandRange forwardedOperands = operands.getForwardedOperands();
43 // Check that the operands are valid.
44 if (forwardedOperands.empty())
45 return std::nullopt;
47 // Check to ensure that this operand is within the range.
48 unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
49 if (operandIndex < operandsStart ||
50 operandIndex >= (operandsStart + forwardedOperands.size()))
51 return std::nullopt;
53 // Index the successor.
54 unsigned argIndex =
55 operands.getProducedOperandCount() + operandIndex - operandsStart;
56 return successor->getArgument(argIndex);
59 /// Verify that the given operands match those of the given successor block.
60 LogicalResult
61 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
62 const SuccessorOperands &operands) {
63 // Check the count.
64 unsigned operandCount = operands.size();
65 Block *destBB = op->getSuccessor(succNo);
66 if (operandCount != destBB->getNumArguments())
67 return op->emitError() << "branch has " << operandCount
68 << " operands for successor #" << succNo
69 << ", but target block has "
70 << destBB->getNumArguments();
72 // Check the types.
73 for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
74 ++i) {
75 if (!cast<BranchOpInterface>(op).areTypesCompatible(
76 operands[i].getType(), destBB->getArgument(i).getType()))
77 return op->emitError() << "type mismatch for bb argument #" << i
78 << " of successor #" << succNo;
80 return success();
83 //===----------------------------------------------------------------------===//
84 // RegionBranchOpInterface
85 //===----------------------------------------------------------------------===//
87 static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
88 RegionBranchPoint sourceNo,
89 RegionBranchPoint succRegionNo) {
90 diag << "from ";
91 if (Region *region = sourceNo.getRegionOrNull())
92 diag << "Region #" << region->getRegionNumber();
93 else
94 diag << "parent operands";
96 diag << " to ";
97 if (Region *region = succRegionNo.getRegionOrNull())
98 diag << "Region #" << region->getRegionNumber();
99 else
100 diag << "parent results";
101 return diag;
104 /// Verify that types match along all region control flow edges originating from
105 /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
106 /// types of the inputs that flow to a successor region.
107 static LogicalResult
108 verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
109 function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
110 getInputsTypesForRegion) {
111 auto regionInterface = cast<RegionBranchOpInterface>(op);
113 SmallVector<RegionSuccessor, 2> successors;
114 regionInterface.getSuccessorRegions(sourcePoint, successors);
116 for (RegionSuccessor &succ : successors) {
117 FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
118 if (failed(sourceTypes))
119 return failure();
121 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
122 if (sourceTypes->size() != succInputsTypes.size()) {
123 InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
124 return printRegionEdgeName(diag, sourcePoint, succ)
125 << ": source has " << sourceTypes->size()
126 << " operands, but target successor needs "
127 << succInputsTypes.size();
130 for (const auto &typesIdx :
131 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
132 Type sourceType = std::get<0>(typesIdx.value());
133 Type inputType = std::get<1>(typesIdx.value());
134 if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
135 InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
136 return printRegionEdgeName(diag, sourcePoint, succ)
137 << ": source type #" << typesIdx.index() << " " << sourceType
138 << " should match input type #" << typesIdx.index() << " "
139 << inputType;
143 return success();
146 /// Verify that types match along control flow edges described the given op.
147 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
148 auto regionInterface = cast<RegionBranchOpInterface>(op);
150 auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
151 return regionInterface.getEntrySuccessorOperands(point).getTypes();
154 // Verify types along control flow edges originating from the parent.
155 if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
156 inputTypesFromParent)))
157 return failure();
159 auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
160 if (lhs.size() != rhs.size())
161 return false;
162 for (auto types : llvm::zip(lhs, rhs)) {
163 if (!regionInterface.areTypesCompatible(std::get<0>(types),
164 std::get<1>(types))) {
165 return false;
168 return true;
171 // Verify types along control flow edges originating from each region.
172 for (Region &region : op->getRegions()) {
174 // Since there can be multiple terminators implementing the
175 // `RegionBranchTerminatorOpInterface`, all should have the same operand
176 // types when passing them to the same region.
178 SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
179 for (Block &block : region)
180 if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
181 block.getTerminator()))
182 regionReturnOps.push_back(terminator);
184 // If there is no return-like terminator, the op itself should verify
185 // type consistency.
186 if (regionReturnOps.empty())
187 continue;
189 auto inputTypesForRegion =
190 [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
191 std::optional<OperandRange> regionReturnOperands;
192 for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
193 auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
195 if (!regionReturnOperands) {
196 regionReturnOperands = terminatorOperands;
197 continue;
200 // Found more than one ReturnLike terminator. Make sure the operand
201 // types match with the first one.
202 if (!areTypesCompatible(regionReturnOperands->getTypes(),
203 terminatorOperands.getTypes())) {
204 InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
205 return printRegionEdgeName(diag, region, point)
206 << " operands mismatch between return-like terminators";
210 // All successors get the same set of operand types.
211 return TypeRange(regionReturnOperands->getTypes());
214 if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
215 return failure();
218 return success();
221 /// Return `true` if region `r` is reachable from region `begin` according to
222 /// the RegionBranchOpInterface (by taking a branch).
223 static bool isRegionReachable(Region *begin, Region *r) {
224 assert(begin->getParentOp() == r->getParentOp() &&
225 "expected that both regions belong to the same op");
226 auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
227 SmallVector<bool> visited(op->getNumRegions(), false);
228 visited[begin->getRegionNumber()] = true;
230 // Retrieve all successors of the region and enqueue them in the worklist.
231 SmallVector<Region *> worklist;
232 auto enqueueAllSuccessors = [&](Region *region) {
233 SmallVector<RegionSuccessor> successors;
234 op.getSuccessorRegions(region, successors);
235 for (RegionSuccessor successor : successors)
236 if (!successor.isParent())
237 worklist.push_back(successor.getSuccessor());
239 enqueueAllSuccessors(begin);
241 // Process all regions in the worklist via DFS.
242 while (!worklist.empty()) {
243 Region *nextRegion = worklist.pop_back_val();
244 if (nextRegion == r)
245 return true;
246 if (visited[nextRegion->getRegionNumber()])
247 continue;
248 visited[nextRegion->getRegionNumber()] = true;
249 enqueueAllSuccessors(nextRegion);
252 return false;
255 /// Return `true` if `a` and `b` are in mutually exclusive regions.
257 /// 1. Find the first common of `a` and `b` (ancestor) that implements
258 /// RegionBranchOpInterface.
259 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
260 /// contained.
261 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
262 /// mutually exclusive if they are not reachable from each other as per
263 /// RegionBranchOpInterface::getSuccessorRegions.
264 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
265 assert(a && "expected non-empty operation");
266 assert(b && "expected non-empty operation");
268 auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
269 while (branchOp) {
270 // Check if b is inside branchOp. (We already know that a is.)
271 if (!branchOp->isProperAncestor(b)) {
272 // Check next enclosing RegionBranchOpInterface.
273 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
274 continue;
277 // b is contained in branchOp. Retrieve the regions in which `a` and `b`
278 // are contained.
279 Region *regionA = nullptr, *regionB = nullptr;
280 for (Region &r : branchOp->getRegions()) {
281 if (r.findAncestorOpInRegion(*a)) {
282 assert(!regionA && "already found a region for a");
283 regionA = &r;
285 if (r.findAncestorOpInRegion(*b)) {
286 assert(!regionB && "already found a region for b");
287 regionB = &r;
290 assert(regionA && regionB && "could not find region of op");
292 // `a` and `b` are in mutually exclusive regions if both regions are
293 // distinct and neither region is reachable from the other region.
294 return regionA != regionB && !isRegionReachable(regionA, regionB) &&
295 !isRegionReachable(regionB, regionA);
298 // Could not find a common RegionBranchOpInterface among a's and b's
299 // ancestors.
300 return false;
303 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
304 Region *region = &getOperation()->getRegion(index);
305 return isRegionReachable(region, region);
308 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
309 while (Region *region = op->getParentRegion()) {
310 op = region->getParentOp();
311 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
312 if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
313 return region;
315 return nullptr;
318 Region *mlir::getEnclosingRepetitiveRegion(Value value) {
319 Region *region = value.getParentRegion();
320 while (region) {
321 Operation *op = region->getParentOp();
322 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
323 if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
324 return region;
325 region = op->getParentRegion();
327 return nullptr;