//===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::dataflow;
30 //===----------------------------------------------------------------------===//
31 // AbstractSparseLattice
32 //===----------------------------------------------------------------------===//
34 void AbstractSparseLattice::onUpdate(DataFlowSolver
*solver
) const {
35 AnalysisState::onUpdate(solver
);
37 // Push all users of the value to the queue.
38 for (Operation
*user
: point
.get
<Value
>().getUsers())
39 for (DataFlowAnalysis
*analysis
: useDefSubscribers
)
40 solver
->enqueue({user
, analysis
});
43 //===----------------------------------------------------------------------===//
44 // AbstractSparseForwardDataFlowAnalysis
45 //===----------------------------------------------------------------------===//
47 AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis(
48 DataFlowSolver
&solver
)
49 : DataFlowAnalysis(solver
) {
50 registerPointKind
<CFGEdge
>();
54 AbstractSparseForwardDataFlowAnalysis::initialize(Operation
*top
) {
55 // Mark the entry block arguments as having reached their pessimistic
57 for (Region
®ion
: top
->getRegions()) {
60 for (Value argument
: region
.front().getArguments())
61 setToEntryState(getLatticeElement(argument
));
64 return initializeRecursively(top
);
68 AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation
*op
) {
69 // Initialize the analysis by visiting every owner of an SSA value (all
70 // operations and blocks).
72 for (Region
®ion
: op
->getRegions()) {
73 for (Block
&block
: region
) {
74 getOrCreate
<Executable
>(&block
)->blockContentSubscribe(this);
76 for (Operation
&op
: block
)
77 if (failed(initializeRecursively(&op
)))
85 LogicalResult
AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point
) {
86 if (Operation
*op
= llvm::dyn_cast_if_present
<Operation
*>(point
))
88 else if (Block
*block
= llvm::dyn_cast_if_present
<Block
*>(point
))
95 void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation
*op
) {
96 // Exit early on operations with no results.
97 if (op
->getNumResults() == 0)
100 // If the containing block is not executable, bail out.
101 if (!getOrCreate
<Executable
>(op
->getBlock())->isLive())
104 // Get the result lattices.
105 SmallVector
<AbstractSparseLattice
*> resultLattices
;
106 resultLattices
.reserve(op
->getNumResults());
107 for (Value result
: op
->getResults()) {
108 AbstractSparseLattice
*resultLattice
= getLatticeElement(result
);
109 resultLattices
.push_back(resultLattice
);
112 // The results of a region branch operation are determined by control-flow.
113 if (auto branch
= dyn_cast
<RegionBranchOpInterface
>(op
)) {
114 return visitRegionSuccessors({branch
}, branch
,
115 /*successor=*/RegionBranchPoint::parent(),
119 // The results of a call operation are determined by the callgraph.
120 if (auto call
= dyn_cast
<CallOpInterface
>(op
)) {
121 const auto *predecessors
= getOrCreateFor
<PredecessorState
>(op
, call
);
122 // If not all return sites are known, then conservatively assume we can't
123 // reason about the data-flow.
124 if (!predecessors
->allPredecessorsKnown())
125 return setAllToEntryStates(resultLattices
);
126 for (Operation
*predecessor
: predecessors
->getKnownPredecessors())
127 for (auto it
: llvm::zip(predecessor
->getOperands(), resultLattices
))
128 join(std::get
<1>(it
), *getLatticeElementFor(op
, std::get
<0>(it
)));
132 // Grab the lattice elements of the operands.
133 SmallVector
<const AbstractSparseLattice
*> operandLattices
;
134 operandLattices
.reserve(op
->getNumOperands());
135 for (Value operand
: op
->getOperands()) {
136 AbstractSparseLattice
*operandLattice
= getLatticeElement(operand
);
137 operandLattice
->useDefSubscribe(this);
138 operandLattices
.push_back(operandLattice
);
141 // Invoke the operation transfer function.
142 visitOperationImpl(op
, operandLattices
, resultLattices
);
145 void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block
*block
) {
146 // Exit early on blocks with no arguments.
147 if (block
->getNumArguments() == 0)
150 // If the block is not executable, bail out.
151 if (!getOrCreate
<Executable
>(block
)->isLive())
154 // Get the argument lattices.
155 SmallVector
<AbstractSparseLattice
*> argLattices
;
156 argLattices
.reserve(block
->getNumArguments());
157 for (BlockArgument argument
: block
->getArguments()) {
158 AbstractSparseLattice
*argLattice
= getLatticeElement(argument
);
159 argLattices
.push_back(argLattice
);
162 // The argument lattices of entry blocks are set by region control-flow or the
164 if (block
->isEntryBlock()) {
165 // Check if this block is the entry block of a callable region.
166 auto callable
= dyn_cast
<CallableOpInterface
>(block
->getParentOp());
167 if (callable
&& callable
.getCallableRegion() == block
->getParent()) {
168 const auto *callsites
= getOrCreateFor
<PredecessorState
>(block
, callable
);
169 // If not all callsites are known, conservatively mark all lattices as
170 // having reached their pessimistic fixpoints.
171 if (!callsites
->allPredecessorsKnown())
172 return setAllToEntryStates(argLattices
);
173 for (Operation
*callsite
: callsites
->getKnownPredecessors()) {
174 auto call
= cast
<CallOpInterface
>(callsite
);
175 for (auto it
: llvm::zip(call
.getArgOperands(), argLattices
))
176 join(std::get
<1>(it
), *getLatticeElementFor(block
, std::get
<0>(it
)));
181 // Check if the lattices can be determined from region control flow.
182 if (auto branch
= dyn_cast
<RegionBranchOpInterface
>(block
->getParentOp())) {
183 return visitRegionSuccessors(block
, branch
, block
->getParent(),
187 // Otherwise, we can't reason about the data-flow.
188 return visitNonControlFlowArgumentsImpl(block
->getParentOp(),
189 RegionSuccessor(block
->getParent()),
190 argLattices
, /*firstIndex=*/0);
193 // Iterate over the predecessors of the non-entry block.
194 for (Block::pred_iterator it
= block
->pred_begin(), e
= block
->pred_end();
196 Block
*predecessor
= *it
;
198 // If the edge from the predecessor block to the current block is not live,
200 auto *edgeExecutable
=
201 getOrCreate
<Executable
>(getProgramPoint
<CFGEdge
>(predecessor
, block
));
202 edgeExecutable
->blockContentSubscribe(this);
203 if (!edgeExecutable
->isLive())
206 // Check if we can reason about the data-flow from the predecessor.
208 dyn_cast
<BranchOpInterface
>(predecessor
->getTerminator())) {
209 SuccessorOperands operands
=
210 branch
.getSuccessorOperands(it
.getSuccessorIndex());
211 for (auto [idx
, lattice
] : llvm::enumerate(argLattices
)) {
212 if (Value operand
= operands
[idx
]) {
213 join(lattice
, *getLatticeElementFor(block
, operand
));
215 // Conservatively consider internally produced arguments as entry
217 setAllToEntryStates(lattice
);
221 return setAllToEntryStates(argLattices
);
226 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
227 ProgramPoint point
, RegionBranchOpInterface branch
,
228 RegionBranchPoint successor
, ArrayRef
<AbstractSparseLattice
*> lattices
) {
229 const auto *predecessors
= getOrCreateFor
<PredecessorState
>(point
, point
);
230 assert(predecessors
->allPredecessorsKnown() &&
231 "unexpected unresolved region successors");
233 for (Operation
*op
: predecessors
->getKnownPredecessors()) {
234 // Get the incoming successor operands.
235 std::optional
<OperandRange
> operands
;
237 // Check if the predecessor is the parent op.
239 operands
= branch
.getEntrySuccessorOperands(successor
);
240 // Otherwise, try to deduce the operands from a region return-like op.
241 } else if (auto regionTerminator
=
242 dyn_cast
<RegionBranchTerminatorOpInterface
>(op
)) {
243 operands
= regionTerminator
.getSuccessorOperands(successor
);
247 // We can't reason about the data-flow.
248 return setAllToEntryStates(lattices
);
251 ValueRange inputs
= predecessors
->getSuccessorInputs(op
);
252 assert(inputs
.size() == operands
->size() &&
253 "expected the same number of successor inputs as operands");
255 unsigned firstIndex
= 0;
256 if (inputs
.size() != lattices
.size()) {
257 if (llvm::dyn_cast_if_present
<Operation
*>(point
)) {
259 firstIndex
= cast
<OpResult
>(inputs
.front()).getResultNumber();
260 visitNonControlFlowArgumentsImpl(
263 branch
->getResults().slice(firstIndex
, inputs
.size())),
264 lattices
, firstIndex
);
267 firstIndex
= cast
<BlockArgument
>(inputs
.front()).getArgNumber();
268 Region
*region
= point
.get
<Block
*>()->getParent();
269 visitNonControlFlowArgumentsImpl(
271 RegionSuccessor(region
, region
->getArguments().slice(
272 firstIndex
, inputs
.size())),
273 lattices
, firstIndex
);
277 for (auto it
: llvm::zip(*operands
, lattices
.drop_front(firstIndex
)))
278 join(std::get
<1>(it
), *getLatticeElementFor(point
, std::get
<0>(it
)));
282 const AbstractSparseLattice
*
283 AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point
,
285 AbstractSparseLattice
*state
= getLatticeElement(value
);
286 addDependency(state
, point
);
290 void AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates(
291 ArrayRef
<AbstractSparseLattice
*> lattices
) {
292 for (AbstractSparseLattice
*lattice
: lattices
)
293 setToEntryState(lattice
);
296 void AbstractSparseForwardDataFlowAnalysis::join(
297 AbstractSparseLattice
*lhs
, const AbstractSparseLattice
&rhs
) {
298 propagateIfChanged(lhs
, lhs
->join(rhs
));
301 //===----------------------------------------------------------------------===//
302 // AbstractSparseBackwardDataFlowAnalysis
303 //===----------------------------------------------------------------------===//
305 AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
306 DataFlowSolver
&solver
, SymbolTableCollection
&symbolTable
)
307 : DataFlowAnalysis(solver
), symbolTable(symbolTable
) {
308 registerPointKind
<CFGEdge
>();
312 AbstractSparseBackwardDataFlowAnalysis::initialize(Operation
*top
) {
313 return initializeRecursively(top
);
317 AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation
*op
) {
319 for (Region
®ion
: op
->getRegions()) {
320 for (Block
&block
: region
) {
321 getOrCreate
<Executable
>(&block
)->blockContentSubscribe(this);
322 // Initialize ops in reverse order, so we can do as much initial
323 // propagation as possible without having to go through the
325 for (auto it
= block
.rbegin(); it
!= block
.rend(); it
++)
326 if (failed(initializeRecursively(&*it
)))
334 AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point
) {
335 if (Operation
*op
= llvm::dyn_cast_if_present
<Operation
*>(point
))
337 else if (llvm::dyn_cast_if_present
<Block
*>(point
))
338 // For backward dataflow, we don't have to do any work for the blocks
339 // themselves. CFG edges between blocks are processed by the BranchOp
340 // logic in `visitOperation`, and entry blocks for functions are tied
341 // to the CallOp arguments by visitOperation.
348 SmallVector
<AbstractSparseLattice
*>
349 AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values
) {
350 SmallVector
<AbstractSparseLattice
*> resultLattices
;
351 resultLattices
.reserve(values
.size());
352 for (Value result
: values
) {
353 AbstractSparseLattice
*resultLattice
= getLatticeElement(result
);
354 resultLattices
.push_back(resultLattice
);
356 return resultLattices
;
359 SmallVector
<const AbstractSparseLattice
*>
360 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
361 ProgramPoint point
, ValueRange values
) {
362 SmallVector
<const AbstractSparseLattice
*> resultLattices
;
363 resultLattices
.reserve(values
.size());
364 for (Value result
: values
) {
365 const AbstractSparseLattice
*resultLattice
=
366 getLatticeElementFor(point
, result
);
367 resultLattices
.push_back(resultLattice
);
369 return resultLattices
;
372 static MutableArrayRef
<OpOperand
> operandsToOpOperands(OperandRange
&operands
) {
373 return MutableArrayRef
<OpOperand
>(operands
.getBase(), operands
.size());
376 void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation
*op
) {
377 // If we're in a dead block, bail out.
378 if (!getOrCreate
<Executable
>(op
->getBlock())->isLive())
381 SmallVector
<AbstractSparseLattice
*> operandLattices
=
382 getLatticeElements(op
->getOperands());
383 SmallVector
<const AbstractSparseLattice
*> resultLattices
=
384 getLatticeElementsFor(op
, op
->getResults());
386 // Block arguments of region branch operations flow back into the operands
388 if (auto branch
= dyn_cast
<RegionBranchOpInterface
>(op
)) {
389 visitRegionSuccessors(branch
, operandLattices
);
393 if (auto branch
= dyn_cast
<BranchOpInterface
>(op
)) {
394 // Block arguments of successor blocks flow back into our operands.
396 // We remember all operands not forwarded to any block in a BitVector.
397 // We can't just cut out a range here, since the non-forwarded ops might
398 // be non-contiguous (if there's more than one successor).
399 BitVector
unaccounted(op
->getNumOperands(), true);
401 for (auto [index
, block
] : llvm::enumerate(op
->getSuccessors())) {
402 SuccessorOperands successorOperands
= branch
.getSuccessorOperands(index
);
403 OperandRange forwarded
= successorOperands
.getForwardedOperands();
404 if (!forwarded
.empty()) {
405 MutableArrayRef
<OpOperand
> operands
= op
->getOpOperands().slice(
406 forwarded
.getBeginOperandIndex(), forwarded
.size());
407 for (OpOperand
&operand
: operands
) {
408 unaccounted
.reset(operand
.getOperandNumber());
409 if (std::optional
<BlockArgument
> blockArg
=
410 detail::getBranchSuccessorArgument(
411 successorOperands
, operand
.getOperandNumber(), block
)) {
412 meet(getLatticeElement(operand
.get()),
413 *getLatticeElementFor(op
, *blockArg
));
418 // Operands not forwarded to successor blocks are typically parameters
419 // of the branch operation itself (for example the boolean for if/else).
420 for (int index
: unaccounted
.set_bits()) {
421 OpOperand
&operand
= op
->getOpOperand(index
);
422 visitBranchOperand(operand
);
427 // For function calls, connect the arguments of the entry blocks to the
428 // operands of the call op that are forwarded to these arguments.
429 if (auto call
= dyn_cast
<CallOpInterface
>(op
)) {
430 Operation
*callableOp
= call
.resolveCallable(&symbolTable
);
431 if (auto callable
= dyn_cast_or_null
<CallableOpInterface
>(callableOp
)) {
432 // Not all operands of a call op forward to arguments. Such operands are
433 // stored in `unaccounted`.
434 BitVector
unaccounted(op
->getNumOperands(), true);
436 OperandRange argOperands
= call
.getArgOperands();
437 MutableArrayRef
<OpOperand
> argOpOperands
=
438 operandsToOpOperands(argOperands
);
439 Region
*region
= callable
.getCallableRegion();
440 if (region
&& !region
->empty()) {
441 Block
&block
= region
->front();
442 for (auto [blockArg
, argOpOperand
] :
443 llvm::zip(block
.getArguments(), argOpOperands
)) {
444 meet(getLatticeElement(argOpOperand
.get()),
445 *getLatticeElementFor(op
, blockArg
));
446 unaccounted
.reset(argOpOperand
.getOperandNumber());
449 // Handle the operands of the call op that aren't forwarded to any
451 for (int index
: unaccounted
.set_bits()) {
452 OpOperand
&opOperand
= op
->getOpOperand(index
);
453 visitCallOperand(opOperand
);
459 // When the region of an op implementing `RegionBranchOpInterface` has a
460 // terminator implementing `RegionBranchTerminatorOpInterface` or a
461 // return-like terminator, the region's successors' arguments flow back into
462 // the "successor operands" of this terminator.
464 // A successor operand with respect to an op implementing
465 // `RegionBranchOpInterface` is an operand that is forwarded to a region
466 // successor's input. There are two types of successor operands: the operands
467 // of this op itself and the operands of the terminators of the regions of
469 if (auto terminator
= dyn_cast
<RegionBranchTerminatorOpInterface
>(op
)) {
470 if (auto branch
= dyn_cast
<RegionBranchOpInterface
>(op
->getParentOp())) {
471 visitRegionSuccessorsFromTerminator(terminator
, branch
);
476 if (op
->hasTrait
<OpTrait::ReturnLike
>()) {
477 // Going backwards, the operands of the return are derived from the
478 // results of all CallOps calling this CallableOp.
479 if (auto callable
= dyn_cast
<CallableOpInterface
>(op
->getParentOp())) {
480 const PredecessorState
*callsites
=
481 getOrCreateFor
<PredecessorState
>(op
, callable
);
482 if (callsites
->allPredecessorsKnown()) {
483 for (Operation
*call
: callsites
->getKnownPredecessors()) {
484 SmallVector
<const AbstractSparseLattice
*> callResultLattices
=
485 getLatticeElementsFor(op
, call
->getResults());
486 for (auto [op
, result
] :
487 llvm::zip(operandLattices
, callResultLattices
))
491 // If we don't know all the callers, we can't know where the
492 // returned values go. Note that, in particular, this will trigger
493 // for the return ops of any public functions.
494 setAllToExitStates(operandLattices
);
500 visitOperationImpl(op
, operandLattices
, resultLattices
);
503 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
504 RegionBranchOpInterface branch
,
505 ArrayRef
<AbstractSparseLattice
*> operandLattices
) {
506 Operation
*op
= branch
.getOperation();
507 SmallVector
<RegionSuccessor
> successors
;
508 SmallVector
<Attribute
> operands(op
->getNumOperands(), nullptr);
509 branch
.getEntrySuccessorRegions(operands
, successors
);
511 // All operands not forwarded to any successor. This set can be non-contiguous
512 // in the presence of multiple successors.
513 BitVector
unaccounted(op
->getNumOperands(), true);
515 for (RegionSuccessor
&successor
: successors
) {
516 OperandRange operands
= branch
.getEntrySuccessorOperands(successor
);
517 MutableArrayRef
<OpOperand
> opoperands
= operandsToOpOperands(operands
);
518 ValueRange inputs
= successor
.getSuccessorInputs();
519 for (auto [operand
, input
] : llvm::zip(opoperands
, inputs
)) {
520 meet(getLatticeElement(operand
.get()), *getLatticeElementFor(op
, input
));
521 unaccounted
.reset(operand
.getOperandNumber());
524 // All operands not forwarded to regions are typically parameters of the
525 // branch operation itself (for example the boolean for if/else).
526 for (int index
: unaccounted
.set_bits()) {
527 visitBranchOperand(op
->getOpOperand(index
));
531 void AbstractSparseBackwardDataFlowAnalysis::
532 visitRegionSuccessorsFromTerminator(
533 RegionBranchTerminatorOpInterface terminator
,
534 RegionBranchOpInterface branch
) {
535 assert(isa
<RegionBranchTerminatorOpInterface
>(terminator
) &&
536 "expected a `RegionBranchTerminatorOpInterface` op");
537 assert(terminator
->getParentOp() == branch
.getOperation() &&
538 "expected `branch` to be the parent op of `terminator`");
540 SmallVector
<Attribute
> operandAttributes(terminator
->getNumOperands(),
542 SmallVector
<RegionSuccessor
> successors
;
543 terminator
.getSuccessorRegions(operandAttributes
, successors
);
544 // All operands not forwarded to any successor. This set can be
545 // non-contiguous in the presence of multiple successors.
546 BitVector
unaccounted(terminator
->getNumOperands(), true);
548 for (const RegionSuccessor
&successor
: successors
) {
549 ValueRange inputs
= successor
.getSuccessorInputs();
550 OperandRange operands
= terminator
.getSuccessorOperands(successor
);
551 MutableArrayRef
<OpOperand
> opOperands
= operandsToOpOperands(operands
);
552 for (auto [opOperand
, input
] : llvm::zip(opOperands
, inputs
)) {
553 meet(getLatticeElement(opOperand
.get()),
554 *getLatticeElementFor(terminator
, input
));
555 unaccounted
.reset(const_cast<OpOperand
&>(opOperand
).getOperandNumber());
558 // Visit operands of the branch op not forwarded to the next region.
559 // (Like e.g. the boolean of `scf.conditional`)
560 for (int index
: unaccounted
.set_bits()) {
561 visitBranchOperand(terminator
->getOpOperand(index
));
565 const AbstractSparseLattice
*
566 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point
,
568 AbstractSparseLattice
*state
= getLatticeElement(value
);
569 addDependency(state
, point
);
573 void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
574 ArrayRef
<AbstractSparseLattice
*> lattices
) {
575 for (AbstractSparseLattice
*lattice
: lattices
)
576 setToExitState(lattice
);
579 void AbstractSparseBackwardDataFlowAnalysis::meet(
580 AbstractSparseLattice
*lhs
, const AbstractSparseLattice
&rhs
) {
581 propagateIfChanged(lhs
, lhs
->meet(rhs
));