[TableGen] Fix validateOperandClass for non-Physical Reg (#118146)
[llvm-project.git] / mlir / lib / Interfaces / ControlFlowInterfaces.cpp
blob39e5e9997f7dd19144d403821775f36d355644a5
1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include <utility>
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 #include "llvm/ADT/SmallPtrSet.h"
15 using namespace mlir;
17 //===----------------------------------------------------------------------===//
18 // ControlFlowInterfaces
19 //===----------------------------------------------------------------------===//
21 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
23 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
24 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
28 MutableOperandRange forwardedOperands)
29 : producedOperandCount(producedOperandCount),
30 forwardedOperands(std::move(forwardedOperands)) {}
32 //===----------------------------------------------------------------------===//
33 // BranchOpInterface
34 //===----------------------------------------------------------------------===//
36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37 /// successor if 'operandIndex' is within the range of 'operands', or
38 /// std::nullopt if `operandIndex` isn't a successor operand index.
39 std::optional<BlockArgument>
40 detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
41 unsigned operandIndex, Block *successor) {
42 OperandRange forwardedOperands = operands.getForwardedOperands();
43 // Check that the operands are valid.
44 if (forwardedOperands.empty())
45 return std::nullopt;
47 // Check to ensure that this operand is within the range.
48 unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
49 if (operandIndex < operandsStart ||
50 operandIndex >= (operandsStart + forwardedOperands.size()))
51 return std::nullopt;
53 // Index the successor.
54 unsigned argIndex =
55 operands.getProducedOperandCount() + operandIndex - operandsStart;
56 return successor->getArgument(argIndex);
59 /// Verify that the given operands match those of the given successor block.
60 LogicalResult
61 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
62 const SuccessorOperands &operands) {
63 // Check the count.
64 unsigned operandCount = operands.size();
65 Block *destBB = op->getSuccessor(succNo);
66 if (operandCount != destBB->getNumArguments())
67 return op->emitError() << "branch has " << operandCount
68 << " operands for successor #" << succNo
69 << ", but target block has "
70 << destBB->getNumArguments();
72 // Check the types.
73 for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
74 ++i) {
75 if (!cast<BranchOpInterface>(op).areTypesCompatible(
76 operands[i].getType(), destBB->getArgument(i).getType()))
77 return op->emitError() << "type mismatch for bb argument #" << i
78 << " of successor #" << succNo;
80 return success();
83 //===----------------------------------------------------------------------===//
84 // RegionBranchOpInterface
85 //===----------------------------------------------------------------------===//
87 static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
88 RegionBranchPoint sourceNo,
89 RegionBranchPoint succRegionNo) {
90 diag << "from ";
91 if (Region *region = sourceNo.getRegionOrNull())
92 diag << "Region #" << region->getRegionNumber();
93 else
94 diag << "parent operands";
96 diag << " to ";
97 if (Region *region = succRegionNo.getRegionOrNull())
98 diag << "Region #" << region->getRegionNumber();
99 else
100 diag << "parent results";
101 return diag;
104 /// Verify that types match along all region control flow edges originating from
105 /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
106 /// types of the inputs that flow to a successor region.
107 static LogicalResult
108 verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
109 function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
110 getInputsTypesForRegion) {
111 auto regionInterface = cast<RegionBranchOpInterface>(op);
113 SmallVector<RegionSuccessor, 2> successors;
114 regionInterface.getSuccessorRegions(sourcePoint, successors);
116 for (RegionSuccessor &succ : successors) {
117 FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
118 if (failed(sourceTypes))
119 return failure();
121 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
122 if (sourceTypes->size() != succInputsTypes.size()) {
123 InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
124 return printRegionEdgeName(diag, sourcePoint, succ)
125 << ": source has " << sourceTypes->size()
126 << " operands, but target successor needs "
127 << succInputsTypes.size();
130 for (const auto &typesIdx :
131 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
132 Type sourceType = std::get<0>(typesIdx.value());
133 Type inputType = std::get<1>(typesIdx.value());
134 if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
135 InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
136 return printRegionEdgeName(diag, sourcePoint, succ)
137 << ": source type #" << typesIdx.index() << " " << sourceType
138 << " should match input type #" << typesIdx.index() << " "
139 << inputType;
143 return success();
146 /// Verify that types match along control flow edges described the given op.
147 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
148 auto regionInterface = cast<RegionBranchOpInterface>(op);
150 auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
151 return regionInterface.getEntrySuccessorOperands(point).getTypes();
154 // Verify types along control flow edges originating from the parent.
155 if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
156 inputTypesFromParent)))
157 return failure();
159 auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
160 if (lhs.size() != rhs.size())
161 return false;
162 for (auto types : llvm::zip(lhs, rhs)) {
163 if (!regionInterface.areTypesCompatible(std::get<0>(types),
164 std::get<1>(types))) {
165 return false;
168 return true;
171 // Verify types along control flow edges originating from each region.
172 for (Region &region : op->getRegions()) {
174 // Since there can be multiple terminators implementing the
175 // `RegionBranchTerminatorOpInterface`, all should have the same operand
176 // types when passing them to the same region.
178 SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
179 for (Block &block : region)
180 if (!block.empty())
181 if (auto terminator =
182 dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
183 regionReturnOps.push_back(terminator);
185 // If there is no return-like terminator, the op itself should verify
186 // type consistency.
187 if (regionReturnOps.empty())
188 continue;
190 auto inputTypesForRegion =
191 [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
192 std::optional<OperandRange> regionReturnOperands;
193 for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
194 auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
196 if (!regionReturnOperands) {
197 regionReturnOperands = terminatorOperands;
198 continue;
201 // Found more than one ReturnLike terminator. Make sure the operand
202 // types match with the first one.
203 if (!areTypesCompatible(regionReturnOperands->getTypes(),
204 terminatorOperands.getTypes())) {
205 InFlightDiagnostic diag = op->emitOpError(" along control flow edge");
206 return printRegionEdgeName(diag, region, point)
207 << " operands mismatch between return-like terminators";
211 // All successors get the same set of operand types.
212 return TypeRange(regionReturnOperands->getTypes());
215 if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
216 return failure();
219 return success();
222 /// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
223 /// this function returns "true" for a successor region. The first parameter is
224 /// the successor region. The second parameter indicates all already visited
225 /// regions.
226 using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>;
228 /// Traverse the region graph starting at `begin`. The traversal is interrupted
229 /// if `stopCondition` evaluates to "true" for a successor region. In that case,
230 /// this function returns "true". Otherwise, if the traversal was not
231 /// interrupted, this function returns "false".
232 static bool traverseRegionGraph(Region *begin,
233 StopConditionFn stopConditionFn) {
234 auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
235 SmallVector<bool> visited(op->getNumRegions(), false);
236 visited[begin->getRegionNumber()] = true;
238 // Retrieve all successors of the region and enqueue them in the worklist.
239 SmallVector<Region *> worklist;
240 auto enqueueAllSuccessors = [&](Region *region) {
241 SmallVector<RegionSuccessor> successors;
242 op.getSuccessorRegions(region, successors);
243 for (RegionSuccessor successor : successors)
244 if (!successor.isParent())
245 worklist.push_back(successor.getSuccessor());
247 enqueueAllSuccessors(begin);
249 // Process all regions in the worklist via DFS.
250 while (!worklist.empty()) {
251 Region *nextRegion = worklist.pop_back_val();
252 if (stopConditionFn(nextRegion, visited))
253 return true;
254 if (visited[nextRegion->getRegionNumber()])
255 continue;
256 visited[nextRegion->getRegionNumber()] = true;
257 enqueueAllSuccessors(nextRegion);
260 return false;
263 /// Return `true` if region `r` is reachable from region `begin` according to
264 /// the RegionBranchOpInterface (by taking a branch).
265 static bool isRegionReachable(Region *begin, Region *r) {
266 assert(begin->getParentOp() == r->getParentOp() &&
267 "expected that both regions belong to the same op");
268 return traverseRegionGraph(begin,
269 [&](Region *nextRegion, ArrayRef<bool> visited) {
270 // Interrupt traversal if `r` was reached.
271 return nextRegion == r;
275 /// Return `true` if `a` and `b` are in mutually exclusive regions.
277 /// 1. Find the first common of `a` and `b` (ancestor) that implements
278 /// RegionBranchOpInterface.
279 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
280 /// contained.
281 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
282 /// mutually exclusive if they are not reachable from each other as per
283 /// RegionBranchOpInterface::getSuccessorRegions.
284 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
285 assert(a && "expected non-empty operation");
286 assert(b && "expected non-empty operation");
288 auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
289 while (branchOp) {
290 // Check if b is inside branchOp. (We already know that a is.)
291 if (!branchOp->isProperAncestor(b)) {
292 // Check next enclosing RegionBranchOpInterface.
293 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
294 continue;
297 // b is contained in branchOp. Retrieve the regions in which `a` and `b`
298 // are contained.
299 Region *regionA = nullptr, *regionB = nullptr;
300 for (Region &r : branchOp->getRegions()) {
301 if (r.findAncestorOpInRegion(*a)) {
302 assert(!regionA && "already found a region for a");
303 regionA = &r;
305 if (r.findAncestorOpInRegion(*b)) {
306 assert(!regionB && "already found a region for b");
307 regionB = &r;
310 assert(regionA && regionB && "could not find region of op");
312 // `a` and `b` are in mutually exclusive regions if both regions are
313 // distinct and neither region is reachable from the other region.
314 return regionA != regionB && !isRegionReachable(regionA, regionB) &&
315 !isRegionReachable(regionB, regionA);
318 // Could not find a common RegionBranchOpInterface among a's and b's
319 // ancestors.
320 return false;
323 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
324 Region *region = &getOperation()->getRegion(index);
325 return isRegionReachable(region, region);
328 bool RegionBranchOpInterface::hasLoop() {
329 SmallVector<RegionSuccessor> entryRegions;
330 getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
331 for (RegionSuccessor successor : entryRegions)
332 if (!successor.isParent() &&
333 traverseRegionGraph(successor.getSuccessor(),
334 [](Region *nextRegion, ArrayRef<bool> visited) {
335 // Interrupt traversal if the region was already
336 // visited.
337 return visited[nextRegion->getRegionNumber()];
339 return true;
340 return false;
343 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
344 while (Region *region = op->getParentRegion()) {
345 op = region->getParentOp();
346 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
347 if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
348 return region;
350 return nullptr;
353 Region *mlir::getEnclosingRepetitiveRegion(Value value) {
354 Region *region = value.getParentRegion();
355 while (region) {
356 Operation *op = region->getParentOp();
357 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
358 if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
359 return region;
360 region = op->getParentRegion();
362 return nullptr;