[AArch64][SME2] Add multi-vector saturating doubling multiply high intrinsics
[llvm-project.git] / mlir / lib / Interfaces / ControlFlowInterfaces.cpp
blobdeec6058c1a0cdc92a03d63fd5329bfbedabbc08
1 //===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include <utility>
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/Interfaces/ControlFlowInterfaces.h"
13 #include "llvm/ADT/SmallPtrSet.h"
15 using namespace mlir;
17 //===----------------------------------------------------------------------===//
18 // ControlFlowInterfaces
19 //===----------------------------------------------------------------------===//
21 #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
23 SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
24 : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
27 SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
28 MutableOperandRange forwardedOperands)
29 : producedOperandCount(producedOperandCount),
30 forwardedOperands(std::move(forwardedOperands)) {}
32 //===----------------------------------------------------------------------===//
33 // BranchOpInterface
34 //===----------------------------------------------------------------------===//
36 /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
37 /// successor if 'operandIndex' is within the range of 'operands', or
38 /// std::nullopt if `operandIndex` isn't a successor operand index.
39 std::optional<BlockArgument>
40 detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
41 unsigned operandIndex, Block *successor) {
42 OperandRange forwardedOperands = operands.getForwardedOperands();
43 // Check that the operands are valid.
44 if (forwardedOperands.empty())
45 return std::nullopt;
47 // Check to ensure that this operand is within the range.
48 unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
49 if (operandIndex < operandsStart ||
50 operandIndex >= (operandsStart + forwardedOperands.size()))
51 return std::nullopt;
53 // Index the successor.
54 unsigned argIndex =
55 operands.getProducedOperandCount() + operandIndex - operandsStart;
56 return successor->getArgument(argIndex);
59 /// Verify that the given operands match those of the given successor block.
60 LogicalResult
61 detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
62 const SuccessorOperands &operands) {
63 // Check the count.
64 unsigned operandCount = operands.size();
65 Block *destBB = op->getSuccessor(succNo);
66 if (operandCount != destBB->getNumArguments())
67 return op->emitError() << "branch has " << operandCount
68 << " operands for successor #" << succNo
69 << ", but target block has "
70 << destBB->getNumArguments();
72 // Check the types.
73 for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
74 ++i) {
75 if (!cast<BranchOpInterface>(op).areTypesCompatible(
76 operands[i].getType(), destBB->getArgument(i).getType()))
77 return op->emitError() << "type mismatch for bb argument #" << i
78 << " of successor #" << succNo;
80 return success();
83 //===----------------------------------------------------------------------===//
84 // RegionBranchOpInterface
85 //===----------------------------------------------------------------------===//
87 /// Verify that types match along all region control flow edges originating from
88 /// `sourceNo` (region # if source is a region, std::nullopt if source is parent
89 /// op). `getInputsTypesForRegion` is a function that returns the types of the
90 /// inputs that flow from `sourceIndex' to the given region, or std::nullopt if
91 /// the exact type match verification is not necessary (e.g., if the Op verifies
92 /// the match itself).
93 static LogicalResult verifyTypesAlongAllEdges(
94 Operation *op, std::optional<unsigned> sourceNo,
95 function_ref<std::optional<TypeRange>(std::optional<unsigned>)>
96 getInputsTypesForRegion) {
97 auto regionInterface = cast<RegionBranchOpInterface>(op);
99 SmallVector<RegionSuccessor, 2> successors;
100 regionInterface.getSuccessorRegions(sourceNo, successors);
102 for (RegionSuccessor &succ : successors) {
103 std::optional<unsigned> succRegionNo;
104 if (!succ.isParent())
105 succRegionNo = succ.getSuccessor()->getRegionNumber();
107 auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
108 diag << "from ";
109 if (sourceNo)
110 diag << "Region #" << sourceNo.value();
111 else
112 diag << "parent operands";
114 diag << " to ";
115 if (succRegionNo)
116 diag << "Region #" << succRegionNo.value();
117 else
118 diag << "parent results";
119 return diag;
122 std::optional<TypeRange> sourceTypes =
123 getInputsTypesForRegion(succRegionNo);
124 if (!sourceTypes.has_value())
125 continue;
127 TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
128 if (sourceTypes->size() != succInputsTypes.size()) {
129 InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
130 return printEdgeName(diag) << ": source has " << sourceTypes->size()
131 << " operands, but target successor needs "
132 << succInputsTypes.size();
135 for (const auto &typesIdx :
136 llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
137 Type sourceType = std::get<0>(typesIdx.value());
138 Type inputType = std::get<1>(typesIdx.value());
139 if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
140 InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
141 return printEdgeName(diag)
142 << ": source type #" << typesIdx.index() << " " << sourceType
143 << " should match input type #" << typesIdx.index() << " "
144 << inputType;
148 return success();
151 /// Verify that types match along control flow edges described the given op.
152 LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
153 auto regionInterface = cast<RegionBranchOpInterface>(op);
155 auto inputTypesFromParent =
156 [&](std::optional<unsigned> regionNo) -> TypeRange {
157 return regionInterface.getSuccessorEntryOperands(regionNo).getTypes();
160 // Verify types along control flow edges originating from the parent.
161 if (failed(verifyTypesAlongAllEdges(op, std::nullopt, inputTypesFromParent)))
162 return failure();
164 auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
165 if (lhs.size() != rhs.size())
166 return false;
167 for (auto types : llvm::zip(lhs, rhs)) {
168 if (!regionInterface.areTypesCompatible(std::get<0>(types),
169 std::get<1>(types))) {
170 return false;
173 return true;
176 // Verify types along control flow edges originating from each region.
177 for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
178 Region &region = op->getRegion(regionNo);
180 // Since there can be multiple `ReturnLike` terminators or others
181 // implementing the `RegionBranchTerminatorOpInterface`, all should have the
182 // same operand types when passing them to the same region.
184 std::optional<OperandRange> regionReturnOperands;
185 for (Block &block : region) {
186 Operation *terminator = block.getTerminator();
187 auto terminatorOperands =
188 getRegionBranchSuccessorOperands(terminator, regionNo);
189 if (!terminatorOperands)
190 continue;
192 if (!regionReturnOperands) {
193 regionReturnOperands = terminatorOperands;
194 continue;
197 // Found more than one ReturnLike terminator. Make sure the operand types
198 // match with the first one.
199 if (!areTypesCompatible(regionReturnOperands->getTypes(),
200 terminatorOperands->getTypes()))
201 return op->emitOpError("Region #")
202 << regionNo
203 << " operands mismatch between return-like terminators";
206 auto inputTypesFromRegion =
207 [&](std::optional<unsigned> regionNo) -> std::optional<TypeRange> {
208 // If there is no return-like terminator, the op itself should verify
209 // type consistency.
210 if (!regionReturnOperands)
211 return std::nullopt;
213 // All successors get the same set of operand types.
214 return TypeRange(regionReturnOperands->getTypes());
217 if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
218 return failure();
221 return success();
224 /// Return `true` if region `r` is reachable from region `begin` according to
225 /// the RegionBranchOpInterface (by taking a branch).
226 static bool isRegionReachable(Region *begin, Region *r) {
227 assert(begin->getParentOp() == r->getParentOp() &&
228 "expected that both regions belong to the same op");
229 auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
230 SmallVector<bool> visited(op->getNumRegions(), false);
231 visited[begin->getRegionNumber()] = true;
233 // Retrieve all successors of the region and enqueue them in the worklist.
234 SmallVector<unsigned> worklist;
235 auto enqueueAllSuccessors = [&](unsigned index) {
236 SmallVector<RegionSuccessor> successors;
237 op.getSuccessorRegions(index, successors);
238 for (RegionSuccessor successor : successors)
239 if (!successor.isParent())
240 worklist.push_back(successor.getSuccessor()->getRegionNumber());
242 enqueueAllSuccessors(begin->getRegionNumber());
244 // Process all regions in the worklist via DFS.
245 while (!worklist.empty()) {
246 unsigned nextRegion = worklist.pop_back_val();
247 if (nextRegion == r->getRegionNumber())
248 return true;
249 if (visited[nextRegion])
250 continue;
251 visited[nextRegion] = true;
252 enqueueAllSuccessors(nextRegion);
255 return false;
258 /// Return `true` if `a` and `b` are in mutually exclusive regions.
260 /// 1. Find the first common of `a` and `b` (ancestor) that implements
261 /// RegionBranchOpInterface.
262 /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
263 /// contained.
264 /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
265 /// mutually exclusive if they are not reachable from each other as per
266 /// RegionBranchOpInterface::getSuccessorRegions.
267 bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
268 assert(a && "expected non-empty operation");
269 assert(b && "expected non-empty operation");
271 auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
272 while (branchOp) {
273 // Check if b is inside branchOp. (We already know that a is.)
274 if (!branchOp->isProperAncestor(b)) {
275 // Check next enclosing RegionBranchOpInterface.
276 branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
277 continue;
280 // b is contained in branchOp. Retrieve the regions in which `a` and `b`
281 // are contained.
282 Region *regionA = nullptr, *regionB = nullptr;
283 for (Region &r : branchOp->getRegions()) {
284 if (r.findAncestorOpInRegion(*a)) {
285 assert(!regionA && "already found a region for a");
286 regionA = &r;
288 if (r.findAncestorOpInRegion(*b)) {
289 assert(!regionB && "already found a region for b");
290 regionB = &r;
293 assert(regionA && regionB && "could not find region of op");
295 // `a` and `b` are in mutually exclusive regions if both regions are
296 // distinct and neither region is reachable from the other region.
297 return regionA != regionB && !isRegionReachable(regionA, regionB) &&
298 !isRegionReachable(regionB, regionA);
301 // Could not find a common RegionBranchOpInterface among a's and b's
302 // ancestors.
303 return false;
306 bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
307 Region *region = &getOperation()->getRegion(index);
308 return isRegionReachable(region, region);
311 void RegionBranchOpInterface::getSuccessorRegions(
312 std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
313 unsigned numInputs = 0;
314 if (index) {
315 // If the predecessor is a region, get the number of operands from an
316 // exiting terminator in the region.
317 for (Block &block : getOperation()->getRegion(*index)) {
318 Operation *terminator = block.getTerminator();
319 if (getRegionBranchSuccessorOperands(terminator, *index)) {
320 numInputs = terminator->getNumOperands();
321 break;
324 } else {
325 // Otherwise, use the number of parent operation operands.
326 numInputs = getOperation()->getNumOperands();
328 SmallVector<Attribute, 2> operands(numInputs, nullptr);
329 getSuccessorRegions(index, operands, regions);
332 Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
333 while (Region *region = op->getParentRegion()) {
334 op = region->getParentOp();
335 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
336 if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
337 return region;
339 return nullptr;
342 Region *mlir::getEnclosingRepetitiveRegion(Value value) {
343 Region *region = value.getParentRegion();
344 while (region) {
345 Operation *op = region->getParentOp();
346 if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
347 if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
348 return region;
349 region = op->getParentRegion();
351 return nullptr;
354 //===----------------------------------------------------------------------===//
355 // RegionBranchTerminatorOpInterface
356 //===----------------------------------------------------------------------===//
358 /// Returns true if the given operation is either annotated with the
359 /// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
360 bool mlir::isRegionReturnLike(Operation *operation) {
361 return dyn_cast<RegionBranchTerminatorOpInterface>(operation) ||
362 operation->hasTrait<OpTrait::ReturnLike>();
365 /// Returns the mutable operands that are passed to the region with the given
366 /// `regionIndex`. If the operation does not implement the
367 /// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
368 /// result will be `std::nullopt`. In all other cases, the resulting
369 /// `OperandRange` represents all operands that are passed to the specified
370 /// successor region. If `regionIndex` is `std::nullopt`, all operands that are
371 /// passed to the parent operation will be returned.
372 std::optional<MutableOperandRange>
373 mlir::getMutableRegionBranchSuccessorOperands(
374 Operation *operation, std::optional<unsigned> regionIndex) {
375 // Try to query a RegionBranchTerminatorOpInterface to determine
376 // all successor operands that will be passed to the successor
377 // input arguments.
378 if (auto regionTerminatorInterface =
379 dyn_cast<RegionBranchTerminatorOpInterface>(operation))
380 return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
382 // TODO: The ReturnLike trait should imply a default implementation of the
383 // RegionBranchTerminatorOpInterface. This would make this code significantly
384 // easier. Furthermore, this may even make this function obsolete.
385 if (operation->hasTrait<OpTrait::ReturnLike>())
386 return MutableOperandRange(operation);
387 return std::nullopt;
390 /// Returns the read only operands that are passed to the region with the given
391 /// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
392 /// information.
393 std::optional<OperandRange>
394 mlir::getRegionBranchSuccessorOperands(Operation *operation,
395 std::optional<unsigned> regionIndex) {
396 auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
397 if (range)
398 return range->operator OperandRange();
399 return std::nullopt;