1 //===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::dataflow;
29 //===----------------------------------------------------------------------===//
30 // AbstractSparseLattice
31 //===----------------------------------------------------------------------===//
33 void AbstractSparseLattice::onUpdate(DataFlowSolver
*solver
) const {
34 AnalysisState::onUpdate(solver
);
36 // Push all users of the value to the queue.
37 for (Operation
*user
: anchor
.get
<Value
>().getUsers())
38 for (DataFlowAnalysis
*analysis
: useDefSubscribers
)
39 solver
->enqueue({solver
->getProgramPointAfter(user
), analysis
});
42 //===----------------------------------------------------------------------===//
43 // AbstractSparseForwardDataFlowAnalysis
44 //===----------------------------------------------------------------------===//
46 AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis(
47 DataFlowSolver
&solver
)
48 : DataFlowAnalysis(solver
) {
49 registerAnchorKind
<CFGEdge
>();
53 AbstractSparseForwardDataFlowAnalysis::initialize(Operation
*top
) {
54 // Mark the entry block arguments as having reached their pessimistic
56 for (Region
®ion
: top
->getRegions()) {
59 for (Value argument
: region
.front().getArguments())
60 setToEntryState(getLatticeElement(argument
));
63 return initializeRecursively(top
);
67 AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation
*op
) {
68 // Initialize the analysis by visiting every owner of an SSA value (all
69 // operations and blocks).
70 if (failed(visitOperation(op
)))
73 for (Region
®ion
: op
->getRegions()) {
74 for (Block
&block
: region
) {
75 getOrCreate
<Executable
>(getProgramPointBefore(&block
))
76 ->blockContentSubscribe(this);
78 for (Operation
&op
: block
)
79 if (failed(initializeRecursively(&op
)))
88 AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint
*point
) {
89 if (!point
->isBlockStart())
90 return visitOperation(point
->getPrevOp());
91 visitBlock(point
->getBlock());
96 AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation
*op
) {
97 // Exit early on operations with no results.
98 if (op
->getNumResults() == 0)
101 // If the containing block is not executable, bail out.
102 if (op
->getBlock() != nullptr &&
103 !getOrCreate
<Executable
>(getProgramPointBefore(op
->getBlock()))->isLive())
106 // Get the result lattices.
107 SmallVector
<AbstractSparseLattice
*> resultLattices
;
108 resultLattices
.reserve(op
->getNumResults());
109 for (Value result
: op
->getResults()) {
110 AbstractSparseLattice
*resultLattice
= getLatticeElement(result
);
111 resultLattices
.push_back(resultLattice
);
114 // The results of a region branch operation are determined by control-flow.
115 if (auto branch
= dyn_cast
<RegionBranchOpInterface
>(op
)) {
116 visitRegionSuccessors(getProgramPointAfter(branch
), branch
,
117 /*successor=*/RegionBranchPoint::parent(),
122 // Grab the lattice elements of the operands.
123 SmallVector
<const AbstractSparseLattice
*> operandLattices
;
124 operandLattices
.reserve(op
->getNumOperands());
125 for (Value operand
: op
->getOperands()) {
126 AbstractSparseLattice
*operandLattice
= getLatticeElement(operand
);
127 operandLattice
->useDefSubscribe(this);
128 operandLattices
.push_back(operandLattice
);
131 if (auto call
= dyn_cast
<CallOpInterface
>(op
)) {
132 // If the call operation is to an external function, attempt to infer the
133 // results from the call arguments.
135 dyn_cast_if_present
<CallableOpInterface
>(call
.resolveCallable());
136 if (!getSolverConfig().isInterprocedural() ||
137 (callable
&& !callable
.getCallableRegion())) {
138 visitExternalCallImpl(call
, operandLattices
, resultLattices
);
142 // Otherwise, the results of a call operation are determined by the
144 const auto *predecessors
= getOrCreateFor
<PredecessorState
>(
145 getProgramPointAfter(op
), getProgramPointAfter(call
));
146 // If not all return sites are known, then conservatively assume we can't
147 // reason about the data-flow.
148 if (!predecessors
->allPredecessorsKnown()) {
149 setAllToEntryStates(resultLattices
);
152 for (Operation
*predecessor
: predecessors
->getKnownPredecessors())
153 for (auto &&[operand
, resLattice
] :
154 llvm::zip(predecessor
->getOperands(), resultLattices
))
156 *getLatticeElementFor(getProgramPointAfter(op
), operand
));
160 // Invoke the operation transfer function.
161 return visitOperationImpl(op
, operandLattices
, resultLattices
);
164 void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block
*block
) {
165 // Exit early on blocks with no arguments.
166 if (block
->getNumArguments() == 0)
169 // If the block is not executable, bail out.
170 if (!getOrCreate
<Executable
>(getProgramPointBefore(block
))->isLive())
173 // Get the argument lattices.
174 SmallVector
<AbstractSparseLattice
*> argLattices
;
175 argLattices
.reserve(block
->getNumArguments());
176 for (BlockArgument argument
: block
->getArguments()) {
177 AbstractSparseLattice
*argLattice
= getLatticeElement(argument
);
178 argLattices
.push_back(argLattice
);
181 // The argument lattices of entry blocks are set by region control-flow or the
183 if (block
->isEntryBlock()) {
184 // Check if this block is the entry block of a callable region.
185 auto callable
= dyn_cast
<CallableOpInterface
>(block
->getParentOp());
186 if (callable
&& callable
.getCallableRegion() == block
->getParent()) {
187 const auto *callsites
= getOrCreateFor
<PredecessorState
>(
188 getProgramPointBefore(block
), getProgramPointAfter(callable
));
189 // If not all callsites are known, conservatively mark all lattices as
190 // having reached their pessimistic fixpoints.
191 if (!callsites
->allPredecessorsKnown() ||
192 !getSolverConfig().isInterprocedural()) {
193 return setAllToEntryStates(argLattices
);
195 for (Operation
*callsite
: callsites
->getKnownPredecessors()) {
196 auto call
= cast
<CallOpInterface
>(callsite
);
197 for (auto it
: llvm::zip(call
.getArgOperands(), argLattices
))
198 join(std::get
<1>(it
),
199 *getLatticeElementFor(getProgramPointBefore(block
),
205 // Check if the lattices can be determined from region control flow.
206 if (auto branch
= dyn_cast
<RegionBranchOpInterface
>(block
->getParentOp())) {
207 return visitRegionSuccessors(getProgramPointBefore(block
), branch
,
208 block
->getParent(), argLattices
);
211 // Otherwise, we can't reason about the data-flow.
212 return visitNonControlFlowArgumentsImpl(block
->getParentOp(),
213 RegionSuccessor(block
->getParent()),
214 argLattices
, /*firstIndex=*/0);
217 // Iterate over the predecessors of the non-entry block.
218 for (Block::pred_iterator it
= block
->pred_begin(), e
= block
->pred_end();
220 Block
*predecessor
= *it
;
222 // If the edge from the predecessor block to the current block is not live,
224 auto *edgeExecutable
=
225 getOrCreate
<Executable
>(getLatticeAnchor
<CFGEdge
>(predecessor
, block
));
226 edgeExecutable
->blockContentSubscribe(this);
227 if (!edgeExecutable
->isLive())
230 // Check if we can reason about the data-flow from the predecessor.
232 dyn_cast
<BranchOpInterface
>(predecessor
->getTerminator())) {
233 SuccessorOperands operands
=
234 branch
.getSuccessorOperands(it
.getSuccessorIndex());
235 for (auto [idx
, lattice
] : llvm::enumerate(argLattices
)) {
236 if (Value operand
= operands
[idx
]) {
238 *getLatticeElementFor(getProgramPointBefore(block
), operand
));
240 // Conservatively consider internally produced arguments as entry
242 setAllToEntryStates(lattice
);
246 return setAllToEntryStates(argLattices
);
251 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
252 ProgramPoint
*point
, RegionBranchOpInterface branch
,
253 RegionBranchPoint successor
, ArrayRef
<AbstractSparseLattice
*> lattices
) {
254 const auto *predecessors
= getOrCreateFor
<PredecessorState
>(point
, point
);
255 assert(predecessors
->allPredecessorsKnown() &&
256 "unexpected unresolved region successors");
258 for (Operation
*op
: predecessors
->getKnownPredecessors()) {
259 // Get the incoming successor operands.
260 std::optional
<OperandRange
> operands
;
262 // Check if the predecessor is the parent op.
264 operands
= branch
.getEntrySuccessorOperands(successor
);
265 // Otherwise, try to deduce the operands from a region return-like op.
266 } else if (auto regionTerminator
=
267 dyn_cast
<RegionBranchTerminatorOpInterface
>(op
)) {
268 operands
= regionTerminator
.getSuccessorOperands(successor
);
272 // We can't reason about the data-flow.
273 return setAllToEntryStates(lattices
);
276 ValueRange inputs
= predecessors
->getSuccessorInputs(op
);
277 assert(inputs
.size() == operands
->size() &&
278 "expected the same number of successor inputs as operands");
280 unsigned firstIndex
= 0;
281 if (inputs
.size() != lattices
.size()) {
282 if (!point
->isBlockStart()) {
284 firstIndex
= cast
<OpResult
>(inputs
.front()).getResultNumber();
285 visitNonControlFlowArgumentsImpl(
288 branch
->getResults().slice(firstIndex
, inputs
.size())),
289 lattices
, firstIndex
);
292 firstIndex
= cast
<BlockArgument
>(inputs
.front()).getArgNumber();
293 Region
*region
= point
->getBlock()->getParent();
294 visitNonControlFlowArgumentsImpl(
296 RegionSuccessor(region
, region
->getArguments().slice(
297 firstIndex
, inputs
.size())),
298 lattices
, firstIndex
);
302 for (auto it
: llvm::zip(*operands
, lattices
.drop_front(firstIndex
)))
303 join(std::get
<1>(it
), *getLatticeElementFor(point
, std::get
<0>(it
)));
307 const AbstractSparseLattice
*
308 AbstractSparseForwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint
*point
,
310 AbstractSparseLattice
*state
= getLatticeElement(value
);
311 addDependency(state
, point
);
315 void AbstractSparseForwardDataFlowAnalysis::setAllToEntryStates(
316 ArrayRef
<AbstractSparseLattice
*> lattices
) {
317 for (AbstractSparseLattice
*lattice
: lattices
)
318 setToEntryState(lattice
);
321 void AbstractSparseForwardDataFlowAnalysis::join(
322 AbstractSparseLattice
*lhs
, const AbstractSparseLattice
&rhs
) {
323 propagateIfChanged(lhs
, lhs
->join(rhs
));
326 //===----------------------------------------------------------------------===//
327 // AbstractSparseBackwardDataFlowAnalysis
328 //===----------------------------------------------------------------------===//
330 AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
331 DataFlowSolver
&solver
, SymbolTableCollection
&symbolTable
)
332 : DataFlowAnalysis(solver
), symbolTable(symbolTable
) {
333 registerAnchorKind
<CFGEdge
>();
337 AbstractSparseBackwardDataFlowAnalysis::initialize(Operation
*top
) {
338 return initializeRecursively(top
);
342 AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation
*op
) {
343 if (failed(visitOperation(op
)))
346 for (Region
®ion
: op
->getRegions()) {
347 for (Block
&block
: region
) {
348 getOrCreate
<Executable
>(getProgramPointBefore(&block
))
349 ->blockContentSubscribe(this);
350 // Initialize ops in reverse order, so we can do as much initial
351 // propagation as possible without having to go through the
353 for (auto it
= block
.rbegin(); it
!= block
.rend(); it
++)
354 if (failed(initializeRecursively(&*it
)))
362 AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint
*point
) {
363 // For backward dataflow, we don't have to do any work for the blocks
364 // themselves. CFG edges between blocks are processed by the BranchOp
365 // logic in `visitOperation`, and entry blocks for functions are tied
366 // to the CallOp arguments by visitOperation.
367 if (point
->isBlockStart())
369 return visitOperation(point
->getPrevOp());
372 SmallVector
<AbstractSparseLattice
*>
373 AbstractSparseBackwardDataFlowAnalysis::getLatticeElements(ValueRange values
) {
374 SmallVector
<AbstractSparseLattice
*> resultLattices
;
375 resultLattices
.reserve(values
.size());
376 for (Value result
: values
) {
377 AbstractSparseLattice
*resultLattice
= getLatticeElement(result
);
378 resultLattices
.push_back(resultLattice
);
380 return resultLattices
;
383 SmallVector
<const AbstractSparseLattice
*>
384 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementsFor(
385 ProgramPoint
*point
, ValueRange values
) {
386 SmallVector
<const AbstractSparseLattice
*> resultLattices
;
387 resultLattices
.reserve(values
.size());
388 for (Value result
: values
) {
389 const AbstractSparseLattice
*resultLattice
=
390 getLatticeElementFor(point
, result
);
391 resultLattices
.push_back(resultLattice
);
393 return resultLattices
;
396 static MutableArrayRef
<OpOperand
> operandsToOpOperands(OperandRange
&operands
) {
397 return MutableArrayRef
<OpOperand
>(operands
.getBase(), operands
.size());
401 AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation
*op
) {
402 // If we're in a dead block, bail out.
403 if (op
->getBlock() != nullptr &&
404 !getOrCreate
<Executable
>(getProgramPointBefore(op
->getBlock()))->isLive())
407 SmallVector
<AbstractSparseLattice
*> operandLattices
=
408 getLatticeElements(op
->getOperands());
409 SmallVector
<const AbstractSparseLattice
*> resultLattices
=
410 getLatticeElementsFor(getProgramPointAfter(op
), op
->getResults());
412 // Block arguments of region branch operations flow back into the operands
414 if (auto branch
= dyn_cast
<RegionBranchOpInterface
>(op
)) {
415 visitRegionSuccessors(branch
, operandLattices
);
419 if (auto branch
= dyn_cast
<BranchOpInterface
>(op
)) {
420 // Block arguments of successor blocks flow back into our operands.
422 // We remember all operands not forwarded to any block in a BitVector.
423 // We can't just cut out a range here, since the non-forwarded ops might
424 // be non-contiguous (if there's more than one successor).
425 BitVector
unaccounted(op
->getNumOperands(), true);
427 for (auto [index
, block
] : llvm::enumerate(op
->getSuccessors())) {
428 SuccessorOperands successorOperands
= branch
.getSuccessorOperands(index
);
429 OperandRange forwarded
= successorOperands
.getForwardedOperands();
430 if (!forwarded
.empty()) {
431 MutableArrayRef
<OpOperand
> operands
= op
->getOpOperands().slice(
432 forwarded
.getBeginOperandIndex(), forwarded
.size());
433 for (OpOperand
&operand
: operands
) {
434 unaccounted
.reset(operand
.getOperandNumber());
435 if (std::optional
<BlockArgument
> blockArg
=
436 detail::getBranchSuccessorArgument(
437 successorOperands
, operand
.getOperandNumber(), block
)) {
438 meet(getLatticeElement(operand
.get()),
439 *getLatticeElementFor(getProgramPointAfter(op
), *blockArg
));
444 // Operands not forwarded to successor blocks are typically parameters
445 // of the branch operation itself (for example the boolean for if/else).
446 for (int index
: unaccounted
.set_bits()) {
447 OpOperand
&operand
= op
->getOpOperand(index
);
448 visitBranchOperand(operand
);
453 // For function calls, connect the arguments of the entry blocks to the
454 // operands of the call op that are forwarded to these arguments.
455 if (auto call
= dyn_cast
<CallOpInterface
>(op
)) {
456 Operation
*callableOp
= call
.resolveCallableInTable(&symbolTable
);
457 if (auto callable
= dyn_cast_or_null
<CallableOpInterface
>(callableOp
)) {
458 // Not all operands of a call op forward to arguments. Such operands are
459 // stored in `unaccounted`.
460 BitVector
unaccounted(op
->getNumOperands(), true);
462 // If the call invokes an external function (or a function treated as
463 // external due to config), defer to the corresponding extension hook.
464 // By default, it just does `visitCallOperand` for all operands.
465 OperandRange argOperands
= call
.getArgOperands();
466 MutableArrayRef
<OpOperand
> argOpOperands
=
467 operandsToOpOperands(argOperands
);
468 Region
*region
= callable
.getCallableRegion();
469 if (!region
|| region
->empty() ||
470 !getSolverConfig().isInterprocedural()) {
471 visitExternalCallImpl(call
, operandLattices
, resultLattices
);
475 // Otherwise, propagate information from the entry point of the function
476 // back to operands whenever possible.
477 Block
&block
= region
->front();
478 for (auto [blockArg
, argOpOperand
] :
479 llvm::zip(block
.getArguments(), argOpOperands
)) {
480 meet(getLatticeElement(argOpOperand
.get()),
481 *getLatticeElementFor(getProgramPointAfter(op
), blockArg
));
482 unaccounted
.reset(argOpOperand
.getOperandNumber());
485 // Handle the operands of the call op that aren't forwarded to any
487 for (int index
: unaccounted
.set_bits()) {
488 OpOperand
&opOperand
= op
->getOpOperand(index
);
489 visitCallOperand(opOperand
);
495 // When the region of an op implementing `RegionBranchOpInterface` has a
496 // terminator implementing `RegionBranchTerminatorOpInterface` or a
497 // return-like terminator, the region's successors' arguments flow back into
498 // the "successor operands" of this terminator.
500 // A successor operand with respect to an op implementing
501 // `RegionBranchOpInterface` is an operand that is forwarded to a region
502 // successor's input. There are two types of successor operands: the operands
503 // of this op itself and the operands of the terminators of the regions of
505 if (auto terminator
= dyn_cast
<RegionBranchTerminatorOpInterface
>(op
)) {
506 if (auto branch
= dyn_cast
<RegionBranchOpInterface
>(op
->getParentOp())) {
507 visitRegionSuccessorsFromTerminator(terminator
, branch
);
512 if (op
->hasTrait
<OpTrait::ReturnLike
>()) {
513 // Going backwards, the operands of the return are derived from the
514 // results of all CallOps calling this CallableOp.
515 if (auto callable
= dyn_cast
<CallableOpInterface
>(op
->getParentOp())) {
516 const PredecessorState
*callsites
= getOrCreateFor
<PredecessorState
>(
517 getProgramPointAfter(op
), getProgramPointAfter(callable
));
518 if (callsites
->allPredecessorsKnown()) {
519 for (Operation
*call
: callsites
->getKnownPredecessors()) {
520 SmallVector
<const AbstractSparseLattice
*> callResultLattices
=
521 getLatticeElementsFor(getProgramPointAfter(op
),
523 for (auto [op
, result
] :
524 llvm::zip(operandLattices
, callResultLattices
))
528 // If we don't know all the callers, we can't know where the
529 // returned values go. Note that, in particular, this will trigger
530 // for the return ops of any public functions.
531 setAllToExitStates(operandLattices
);
537 return visitOperationImpl(op
, operandLattices
, resultLattices
);
540 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
541 RegionBranchOpInterface branch
,
542 ArrayRef
<AbstractSparseLattice
*> operandLattices
) {
543 Operation
*op
= branch
.getOperation();
544 SmallVector
<RegionSuccessor
> successors
;
545 SmallVector
<Attribute
> operands(op
->getNumOperands(), nullptr);
546 branch
.getEntrySuccessorRegions(operands
, successors
);
548 // All operands not forwarded to any successor. This set can be non-contiguous
549 // in the presence of multiple successors.
550 BitVector
unaccounted(op
->getNumOperands(), true);
552 for (RegionSuccessor
&successor
: successors
) {
553 OperandRange operands
= branch
.getEntrySuccessorOperands(successor
);
554 MutableArrayRef
<OpOperand
> opoperands
= operandsToOpOperands(operands
);
555 ValueRange inputs
= successor
.getSuccessorInputs();
556 for (auto [operand
, input
] : llvm::zip(opoperands
, inputs
)) {
557 meet(getLatticeElement(operand
.get()),
558 *getLatticeElementFor(getProgramPointAfter(op
), input
));
559 unaccounted
.reset(operand
.getOperandNumber());
562 // All operands not forwarded to regions are typically parameters of the
563 // branch operation itself (for example the boolean for if/else).
564 for (int index
: unaccounted
.set_bits()) {
565 visitBranchOperand(op
->getOpOperand(index
));
569 void AbstractSparseBackwardDataFlowAnalysis::
570 visitRegionSuccessorsFromTerminator(
571 RegionBranchTerminatorOpInterface terminator
,
572 RegionBranchOpInterface branch
) {
573 assert(isa
<RegionBranchTerminatorOpInterface
>(terminator
) &&
574 "expected a `RegionBranchTerminatorOpInterface` op");
575 assert(terminator
->getParentOp() == branch
.getOperation() &&
576 "expected `branch` to be the parent op of `terminator`");
578 SmallVector
<Attribute
> operandAttributes(terminator
->getNumOperands(),
580 SmallVector
<RegionSuccessor
> successors
;
581 terminator
.getSuccessorRegions(operandAttributes
, successors
);
582 // All operands not forwarded to any successor. This set can be
583 // non-contiguous in the presence of multiple successors.
584 BitVector
unaccounted(terminator
->getNumOperands(), true);
586 for (const RegionSuccessor
&successor
: successors
) {
587 ValueRange inputs
= successor
.getSuccessorInputs();
588 OperandRange operands
= terminator
.getSuccessorOperands(successor
);
589 MutableArrayRef
<OpOperand
> opOperands
= operandsToOpOperands(operands
);
590 for (auto [opOperand
, input
] : llvm::zip(opOperands
, inputs
)) {
591 meet(getLatticeElement(opOperand
.get()),
592 *getLatticeElementFor(getProgramPointAfter(terminator
), input
));
593 unaccounted
.reset(const_cast<OpOperand
&>(opOperand
).getOperandNumber());
596 // Visit operands of the branch op not forwarded to the next region.
597 // (Like e.g. the boolean of `scf.conditional`)
598 for (int index
: unaccounted
.set_bits()) {
599 visitBranchOperand(terminator
->getOpOperand(index
));
603 const AbstractSparseLattice
*
604 AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(
605 ProgramPoint
*point
, Value value
) {
606 AbstractSparseLattice
*state
= getLatticeElement(value
);
607 addDependency(state
, point
);
611 void AbstractSparseBackwardDataFlowAnalysis::setAllToExitStates(
612 ArrayRef
<AbstractSparseLattice
*> lattices
) {
613 for (AbstractSparseLattice
*lattice
: lattices
)
614 setToExitState(lattice
);
617 void AbstractSparseBackwardDataFlowAnalysis::meet(
618 AbstractSparseLattice
*lhs
, const AbstractSparseLattice
&rhs
) {
619 propagateIfChanged(lhs
, lhs
->meet(rhs
));