[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / lib / Analysis / DataFlow / SparseAnalysis.cpp
blob9f544d656df92568ede68697ce2a21e7ce73a2ae
1 //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
10 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
11 #include "mlir/Analysis/DataFlowFramework.h"
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/Operation.h"
14 #include "mlir/IR/Region.h"
15 #include "mlir/IR/SymbolTable.h"
16 #include "mlir/IR/Value.h"
17 #include "mlir/IR/ValueRange.h"
18 #include "mlir/Interfaces/CallInterfaces.h"
19 #include "mlir/Interfaces/ControlFlowInterfaces.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Support/LogicalResult.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/Casting.h"
24 #include <cassert>
25 #include <optional>
27 using namespace mlir;
28 using namespace mlir::dataflow;
30 //===----------------------------------------------------------------------===//
31 // AbstractSparseLattice
32 //===----------------------------------------------------------------------===//
34 void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
35 AnalysisState::onUpdate(solver);
37 // Push all users of the value to the queue.
38 for (Operation *user : point.get<Value>().getUsers())
39 for (DataFlowAnalysis *analysis : useDefSubscribers)
40 solver->enqueue({user, analysis});
43 //===----------------------------------------------------------------------===//
44 // AbstractSparseForwardDataFlowAnalysis
45 //===----------------------------------------------------------------------===//
/// Construct the forward sparse analysis and register the CFG-edge program
/// point kind, which this analysis uses to track liveness of block-to-block
/// control-flow edges.
AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis(
    DataFlowSolver &solver)
    : DataFlowAnalysis(solver) {
  registerPointKind<CFGEdge>();
}
53 LogicalResult
54 AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
55 // Mark the entry block arguments as having reached their pessimistic
56 // fixpoints.
57 for (Region &region : top->getRegions()) {
58 if (region.empty())
59 continue;
60 for (Value argument : region.front().getArguments())
61 setToEntryState(getLatticeElement(argument));
64 return initializeRecursively(top);
67 LogicalResult
68 AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
69 // Initialize the analysis by visiting every owner of an SSA value (all
70 // operations and blocks).
71 visitOperation(op);
72 for (Region &region : op->getRegions()) {
73 for (Block &block : region) {
74 getOrCreate<Executable>(&block)->blockContentSubscribe(this);
75 visitBlock(&block);
76 for (Operation &op : block)
77 if (failed(initializeRecursively(&op)))
78 return failure();
82 return success();
85 LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) {
86 if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
87 visitOperation(op);
88 else if (Block *block = llvm::dyn_cast_if_present<Block *>(point))
89 visitBlock(block);
90 else
91 return failure();
92 return success();
/// Driver for the forward transfer function of a single live operation:
/// results of region-branch ops come from control-flow, results of calls come
/// from the callgraph, and everything else goes through `visitOperationImpl`.
void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
  // Exit early on operations with no results.
  if (op->getNumResults() == 0)
    return;

  // If the containing block is not executable, bail out.
  if (!getOrCreate<Executable>(op->getBlock())->isLive())
    return;

  // Get the result lattices.
  SmallVector<AbstractSparseLattice *> resultLattices;
  resultLattices.reserve(op->getNumResults());
  for (Value result : op->getResults()) {
    AbstractSparseLattice *resultLattice = getLatticeElement(result);
    resultLattices.push_back(resultLattice);
  }

  // The results of a region branch operation are determined by control-flow.
  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
    return visitRegionSuccessors({branch}, branch,
                                 /*successor=*/RegionBranchPoint::parent(),
                                 resultLattices);
  }

  // The results of a call operation are determined by the callgraph.
  if (auto call = dyn_cast<CallOpInterface>(op)) {
    const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
    // If not all return sites are known, then conservatively assume we can't
    // reason about the data-flow.
    if (!predecessors->allPredecessorsKnown())
      return setAllToEntryStates(resultLattices);
    // Join the operands of every known return site into the call results.
    for (Operation *predecessor : predecessors->getKnownPredecessors())
      for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
        join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
    return;
  }

  // Grab the lattice elements of the operands.
  SmallVector<const AbstractSparseLattice *> operandLattices;
  operandLattices.reserve(op->getNumOperands());
  for (Value operand : op->getOperands()) {
    AbstractSparseLattice *operandLattice = getLatticeElement(operand);
    // Subscribe so this op is revisited whenever an operand lattice changes.
    operandLattice->useDefSubscribe(this);
    operandLattices.push_back(operandLattice);
  }

  // Invoke the operation transfer function.
  visitOperationImpl(op, operandLattices, resultLattices);
}
/// Propagate lattices into the arguments of a live block: entry-block
/// arguments come from the callgraph or region control-flow; non-entry block
/// arguments come from the forwarded operands of live predecessor branches.
void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
  // Exit early on blocks with no arguments.
  if (block->getNumArguments() == 0)
    return;

  // If the block is not executable, bail out.
  if (!getOrCreate<Executable>(block)->isLive())
    return;

  // Get the argument lattices.
  SmallVector<AbstractSparseLattice *> argLattices;
  argLattices.reserve(block->getNumArguments());
  for (BlockArgument argument : block->getArguments()) {
    AbstractSparseLattice *argLattice = getLatticeElement(argument);
    argLattices.push_back(argLattice);
  }

  // The argument lattices of entry blocks are set by region control-flow or the
  // callgraph.
  if (block->isEntryBlock()) {
    // Check if this block is the entry block of a callable region.
    auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
    if (callable && callable.getCallableRegion() == block->getParent()) {
      const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
      // If not all callsites are known, conservatively mark all lattices as
      // having reached their pessimistic fixpoints.
      if (!callsites->allPredecessorsKnown())
        return setAllToEntryStates(argLattices);
      // Join the forwarded argument operands of every known call site.
      for (Operation *callsite : callsites->getKnownPredecessors()) {
        auto call = cast<CallOpInterface>(callsite);
        for (auto it : llvm::zip(call.getArgOperands(), argLattices))
          join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
      }
      return;
    }

    // Check if the lattices can be determined from region control flow.
    if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
      return visitRegionSuccessors(block, branch, block->getParent(),
                                   argLattices);
    }

    // Otherwise, we can't reason about the data-flow.
    return visitNonControlFlowArgumentsImpl(block->getParentOp(),
                                            RegionSuccessor(block->getParent()),
                                            argLattices, /*firstIndex=*/0);
  }

  // Iterate over the predecessors of the non-entry block.
  for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
       it != e; ++it) {
    Block *predecessor = *it;

    // If the edge from the predecessor block to the current block is not live,
    // bail out.
    auto *edgeExecutable =
        getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
    edgeExecutable->blockContentSubscribe(this);
    if (!edgeExecutable->isLive())
      continue;

    // Check if we can reason about the data-flow from the predecessor.
    if (auto branch =
            dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
      SuccessorOperands operands =
          branch.getSuccessorOperands(it.getSuccessorIndex());
      for (auto [idx, lattice] : llvm::enumerate(argLattices)) {
        // `operands[idx]` is null for arguments produced internally by the
        // branch rather than forwarded from an operand.
        if (Value operand = operands[idx]) {
          join(lattice, *getLatticeElementFor(block, operand));
        } else {
          // Conservatively consider internally produced arguments as entry
          // points.
          setAllToEntryStates(lattice);
        }
      }
    } else {
      // Unknown terminator kind: nothing can be deduced about the arguments.
      return setAllToEntryStates(argLattices);
    }
  }
}
/// Join the successor operands of every known control-flow predecessor of
/// `successor` into `lattices`. `lattices` covers either the parent op's
/// results or a region entry block's arguments; when a predecessor only
/// forwards a subset of them, the remaining ones are handed to
/// `visitNonControlFlowArgumentsImpl`.
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
    ProgramPoint point, RegionBranchOpInterface branch,
    RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
  const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
  assert(predecessors->allPredecessorsKnown() &&
         "unexpected unresolved region successors");

  for (Operation *op : predecessors->getKnownPredecessors()) {
    // Get the incoming successor operands.
    std::optional<OperandRange> operands;

    // Check if the predecessor is the parent op.
    if (op == branch) {
      operands = branch.getEntrySuccessorOperands(successor);
      // Otherwise, try to deduce the operands from a region return-like op.
    } else if (auto regionTerminator =
                   dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
      operands = regionTerminator.getSuccessorOperands(successor);
    }

    if (!operands) {
      // We can't reason about the data-flow.
      return setAllToEntryStates(lattices);
    }

    ValueRange inputs = predecessors->getSuccessorInputs(op);
    assert(inputs.size() == operands->size() &&
           "expected the same number of successor inputs as operands");

    // When fewer inputs than lattices are forwarded, locate the forwarded
    // window inside `lattices` and let the non-control-flow hook handle the
    // rest.
    unsigned firstIndex = 0;
    if (inputs.size() != lattices.size()) {
      if (llvm::dyn_cast_if_present<Operation *>(point)) {
        // `lattices` are op results; index by result number.
        if (!inputs.empty())
          firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
        visitNonControlFlowArgumentsImpl(
            branch,
            RegionSuccessor(
                branch->getResults().slice(firstIndex, inputs.size())),
            lattices, firstIndex);
      } else {
        // `lattices` are block arguments; index by argument number.
        if (!inputs.empty())
          firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
        Region *region = point.get<Block *>()->getParent();
        visitNonControlFlowArgumentsImpl(
            branch,
            RegionSuccessor(region, region->getArguments().slice(
                                        firstIndex, inputs.size())),
            lattices, firstIndex);
      }
    }

    for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
      join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
  }
}
282 const AbstractSparseLattice *
283 AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
284 Value value) {
285 AbstractSparseLattice *state = getLatticeElement(value);
286 addDependency(state, point);
287 return state;
290 void AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates(
291 ArrayRef<AbstractSparseLattice *> lattices) {
292 for (AbstractSparseLattice *lattice : lattices)
293 setToEntryState(lattice);
296 void AbstractSparseForwardDataFlowAnalysis::join(
297 AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
298 propagateIfChanged(lhs, lhs->join(rhs));
301 //===----------------------------------------------------------------------===//
302 // AbstractSparseBackwardDataFlowAnalysis
303 //===----------------------------------------------------------------------===//
/// Construct the backward sparse analysis. `symbolTable` is kept to resolve
/// call ops to their callables; the CFG-edge point kind is registered for
/// edge-liveness tracking.
AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
    DataFlowSolver &solver, SymbolTableCollection &symbolTable)
    : DataFlowAnalysis(solver), symbolTable(symbolTable) {
  registerPointKind<CFGEdge>();
}
311 LogicalResult
312 AbstractSparseBackwardDataFlowAnalysis::initialize(Operation *top) {
313 return initializeRecursively(top);
316 LogicalResult
317 AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
318 visitOperation(op);
319 for (Region &region : op->getRegions()) {
320 for (Block &block : region) {
321 getOrCreate<Executable>(&block)->blockContentSubscribe(this);
322 // Initialize ops in reverse order, so we can do as much initial
323 // propagation as possible without having to go through the
324 // solver queue.
325 for (auto it = block.rbegin(); it != block.rend(); it++)
326 if (failed(initializeRecursively(&*it)))
327 return failure();
330 return success();
333 LogicalResult
334 AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
335 if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
336 visitOperation(op);
337 else if (llvm::dyn_cast_if_present<Block *>(point))
338 // For backward dataflow, we don't have to do any work for the blocks
339 // themselves. CFG edges between blocks are processed by the BranchOp
340 // logic in `visitOperation`, and entry blocks for functions are tied
341 // to the CallOp arguments by visitOperation.
342 return success();
343 else
344 return failure();
345 return success();
348 SmallVector<AbstractSparseLattice *>
349 AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values) {
350 SmallVector<AbstractSparseLattice *> resultLattices;
351 resultLattices.reserve(values.size());
352 for (Value result : values) {
353 AbstractSparseLattice *resultLattice = getLatticeElement(result);
354 resultLattices.push_back(resultLattice);
356 return resultLattices;
359 SmallVector<const AbstractSparseLattice *>
360 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
361 ProgramPoint point, ValueRange values) {
362 SmallVector<const AbstractSparseLattice *> resultLattices;
363 resultLattices.reserve(values.size());
364 for (Value result : values) {
365 const AbstractSparseLattice *resultLattice =
366 getLatticeElementFor(point, result);
367 resultLattices.push_back(resultLattice);
369 return resultLattices;
/// View an `OperandRange` as a mutable range over the underlying
/// `OpOperand` storage. Uses the range's base pointer and size, so it assumes
/// the range is a contiguous slice of the owning op's operand list.
static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
  return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
}
/// Driver for the backward transfer function of a single live operation.
/// Control-flow ops (region branches, CFG branches, calls, region/function
/// terminators) propagate result/argument lattices back into the operands
/// that feed them; all remaining ops go through `visitOperationImpl`.
void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
  // If we're in a dead block, bail out.
  if (!getOrCreate<Executable>(op->getBlock())->isLive())
    return;

  SmallVector<AbstractSparseLattice *> operandLattices =
      getLatticeElements(op->getOperands());
  SmallVector<const AbstractSparseLattice *> resultLattices =
      getLatticeElementsFor(op, op->getResults());

  // Block arguments of region branch operations flow back into the operands
  // of the parent op
  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
    visitRegionSuccessors(branch, operandLattices);
    return;
  }

  if (auto branch = dyn_cast<BranchOpInterface>(op)) {
    // Block arguments of successor blocks flow back into our operands.

    // We remember all operands not forwarded to any block in a BitVector.
    // We can't just cut out a range here, since the non-forwarded ops might
    // be non-contiguous (if there's more than one successor).
    BitVector unaccounted(op->getNumOperands(), true);

    for (auto [index, block] : llvm::enumerate(op->getSuccessors())) {
      SuccessorOperands successorOperands = branch.getSuccessorOperands(index);
      OperandRange forwarded = successorOperands.getForwardedOperands();
      if (!forwarded.empty()) {
        MutableArrayRef<OpOperand> operands = op->getOpOperands().slice(
            forwarded.getBeginOperandIndex(), forwarded.size());
        for (OpOperand &operand : operands) {
          unaccounted.reset(operand.getOperandNumber());
          // Meet the successor's matching block argument into the operand.
          if (std::optional<BlockArgument> blockArg =
                  detail::getBranchSuccessorArgument(
                      successorOperands, operand.getOperandNumber(), block)) {
            meet(getLatticeElement(operand.get()),
                 *getLatticeElementFor(op, *blockArg));
          }
        }
      }
    }
    // Operands not forwarded to successor blocks are typically parameters
    // of the branch operation itself (for example the boolean for if/else).
    for (int index : unaccounted.set_bits()) {
      OpOperand &operand = op->getOpOperand(index);
      visitBranchOperand(operand);
    }
    return;
  }

  // For function calls, connect the arguments of the entry blocks to the
  // operands of the call op that are forwarded to these arguments.
  if (auto call = dyn_cast<CallOpInterface>(op)) {
    Operation *callableOp = call.resolveCallable(&symbolTable);
    if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
      // Not all operands of a call op forward to arguments. Such operands are
      // stored in `unaccounted`.
      BitVector unaccounted(op->getNumOperands(), true);

      OperandRange argOperands = call.getArgOperands();
      MutableArrayRef<OpOperand> argOpOperands =
          operandsToOpOperands(argOperands);
      Region *region = callable.getCallableRegion();
      if (region && !region->empty()) {
        Block &block = region->front();
        for (auto [blockArg, argOpOperand] :
             llvm::zip(block.getArguments(), argOpOperands)) {
          meet(getLatticeElement(argOpOperand.get()),
               *getLatticeElementFor(op, blockArg));
          unaccounted.reset(argOpOperand.getOperandNumber());
        }
      }

      // Handle the operands of the call op that aren't forwarded to any
      // arguments.
      for (int index : unaccounted.set_bits()) {
        OpOperand &opOperand = op->getOpOperand(index);
        visitCallOperand(opOperand);
      }
      return;
    }
  }

  // When the region of an op implementing `RegionBranchOpInterface` has a
  // terminator implementing `RegionBranchTerminatorOpInterface` or a
  // return-like terminator, the region's successors' arguments flow back into
  // the "successor operands" of this terminator.
  //
  // A successor operand with respect to an op implementing
  // `RegionBranchOpInterface` is an operand that is forwarded to a region
  // successor's input. There are two types of successor operands: the operands
  // of this op itself and the operands of the terminators of the regions of
  // this op.
  if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
    if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
      visitRegionSuccessorsFromTerminator(terminator, branch);
      return;
    }
  }

  if (op->hasTrait<OpTrait::ReturnLike>()) {
    // Going backwards, the operands of the return are derived from the
    // results of all CallOps calling this CallableOp.
    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
      const PredecessorState *callsites =
          getOrCreateFor<PredecessorState>(op, callable);
      if (callsites->allPredecessorsKnown()) {
        for (Operation *call : callsites->getKnownPredecessors()) {
          SmallVector<const AbstractSparseLattice *> callResultLattices =
              getLatticeElementsFor(op, call->getResults());
          // NOTE: the binding `op` here shadows the outer `op` parameter; it
          // names an operand lattice, not an operation.
          for (auto [op, result] :
               llvm::zip(operandLattices, callResultLattices))
            meet(op, *result);
        }
      } else {
        // If we don't know all the callers, we can't know where the
        // returned values go. Note that, in particular, this will trigger
        // for the return ops of any public functions.
        setAllToExitStates(operandLattices);
      }
      return;
    }
  }

  visitOperationImpl(op, operandLattices, resultLattices);
}
/// Propagate region-successor inputs of `branch` back into the entry
/// successor operands of the op; operands not forwarded to any region are
/// handed to `visitBranchOperand`.
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
    RegionBranchOpInterface branch,
    ArrayRef<AbstractSparseLattice *> operandLattices) {
  Operation *op = branch.getOperation();
  SmallVector<RegionSuccessor> successors;
  // Query successors with all-null constant operands, i.e. without assuming
  // anything about the runtime values of the operands.
  SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
  branch.getEntrySuccessorRegions(operands, successors);

  // All operands not forwarded to any successor. This set can be non-contiguous
  // in the presence of multiple successors.
  BitVector unaccounted(op->getNumOperands(), true);

  for (RegionSuccessor &successor : successors) {
    OperandRange operands = branch.getEntrySuccessorOperands(successor);
    MutableArrayRef<OpOperand> opoperands = operandsToOpOperands(operands);
    ValueRange inputs = successor.getSuccessorInputs();
    for (auto [operand, input] : llvm::zip(opoperands, inputs)) {
      meet(getLatticeElement(operand.get()), *getLatticeElementFor(op, input));
      unaccounted.reset(operand.getOperandNumber());
    }
  }
  // All operands not forwarded to regions are typically parameters of the
  // branch operation itself (for example the boolean for if/else).
  for (int index : unaccounted.set_bits()) {
    visitBranchOperand(op->getOpOperand(index));
  }
}
/// Propagate the successor inputs reachable from a region terminator back
/// into that terminator's successor operands; operands not forwarded to any
/// successor region are handed to `visitBranchOperand`.
void AbstractSparseBackwardDataFlowAnalysis::
    visitRegionSuccessorsFromTerminator(
        RegionBranchTerminatorOpInterface terminator,
        RegionBranchOpInterface branch) {
  assert(isa<RegionBranchTerminatorOpInterface>(terminator) &&
         "expected a `RegionBranchTerminatorOpInterface` op");
  assert(terminator->getParentOp() == branch.getOperation() &&
         "expected `branch` to be the parent op of `terminator`");

  // Query successors with all-null constant operands, i.e. without assuming
  // anything about the runtime values of the operands.
  SmallVector<Attribute> operandAttributes(terminator->getNumOperands(),
                                           nullptr);
  SmallVector<RegionSuccessor> successors;
  terminator.getSuccessorRegions(operandAttributes, successors);
  // All operands not forwarded to any successor. This set can be
  // non-contiguous in the presence of multiple successors.
  BitVector unaccounted(terminator->getNumOperands(), true);

  for (const RegionSuccessor &successor : successors) {
    ValueRange inputs = successor.getSuccessorInputs();
    OperandRange operands = terminator.getSuccessorOperands(successor);
    MutableArrayRef<OpOperand> opOperands = operandsToOpOperands(operands);
    for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) {
      meet(getLatticeElement(opOperand.get()),
           *getLatticeElementFor(terminator, input));
      unaccounted.reset(const_cast<OpOperand &>(opOperand).getOperandNumber());
    }
  }
  // Visit operands of the branch op not forwarded to the next region.
  // (Like e.g. the boolean of `scf.conditional`)
  for (int index : unaccounted.set_bits()) {
    visitBranchOperand(terminator->getOpOperand(index));
  }
}
565 const AbstractSparseLattice *
566 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
567 Value value) {
568 AbstractSparseLattice *state = getLatticeElement(value);
569 addDependency(state, point);
570 return state;
573 void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
574 ArrayRef<AbstractSparseLattice *> lattices) {
575 for (AbstractSparseLattice *lattice : lattices)
576 setToExitState(lattice);
579 void AbstractSparseBackwardDataFlowAnalysis::meet(
580 AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs) {
581 propagateIfChanged(lhs, lhs->meet(rhs));