//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass performs a sparse conditional constant propagation
// in MLIR. It identifies values known to be constant, propagates that
// information throughout the IR, and replaces them. This is done with an
// optimistic dataflow analysis that assumes that all values are constant until
// proven otherwise.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"

#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"

namespace mlir {
#define GEN_PASS_DEF_SCCP
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::dataflow;

//===----------------------------------------------------------------------===//
// Rewriter
//===----------------------------------------------------------------------===//
39 /// Replace the given value with a constant if the corresponding lattice
40 /// represents a constant. Returns success if the value was replaced, failure
42 static LogicalResult
replaceWithConstant(DataFlowSolver
&solver
,
44 OperationFolder
&folder
, Value value
) {
45 auto *lattice
= solver
.lookupState
<Lattice
<ConstantValue
>>(value
);
46 if (!lattice
|| lattice
->getValue().isUninitialized())
48 const ConstantValue
&latticeValue
= lattice
->getValue();
49 if (!latticeValue
.getConstantValue())
52 // Attempt to materialize a constant for the given value.
53 Dialect
*dialect
= latticeValue
.getConstantDialect();
54 Value constant
= folder
.getOrCreateConstant(
55 builder
.getInsertionBlock(), dialect
, latticeValue
.getConstantValue(),
60 value
.replaceAllUsesWith(constant
);
64 /// Rewrite the given regions using the computing analysis. This replaces the
65 /// uses of all values that have been computed to be constant, and erases as
66 /// many newly dead operations.
67 static void rewrite(DataFlowSolver
&solver
, MLIRContext
*context
,
68 MutableArrayRef
<Region
> initialRegions
) {
69 SmallVector
<Block
*> worklist
;
70 auto addToWorklist
= [&](MutableArrayRef
<Region
> regions
) {
71 for (Region
®ion
: regions
)
72 for (Block
&block
: llvm::reverse(region
))
73 worklist
.push_back(&block
);
76 // An operation folder used to create and unique constants.
77 OperationFolder
folder(context
);
78 OpBuilder
builder(context
);
80 addToWorklist(initialRegions
);
81 while (!worklist
.empty()) {
82 Block
*block
= worklist
.pop_back_val();
84 for (Operation
&op
: llvm::make_early_inc_range(*block
)) {
85 builder
.setInsertionPoint(&op
);
87 // Replace any result with constants.
88 bool replacedAll
= op
.getNumResults() != 0;
89 for (Value res
: op
.getResults())
91 succeeded(replaceWithConstant(solver
, builder
, folder
, res
));
93 // If all of the results of the operation were replaced, try to erase
94 // the operation completely.
95 if (replacedAll
&& wouldOpBeTriviallyDead(&op
)) {
96 assert(op
.use_empty() && "expected all uses to be replaced");
101 // Add any the regions of this operation to the worklist.
102 addToWorklist(op
.getRegions());
105 // Replace any block arguments with constants.
106 builder
.setInsertionPointToStart(block
);
107 for (BlockArgument arg
: block
->getArguments())
108 (void)replaceWithConstant(solver
, builder
, folder
, arg
);

//===----------------------------------------------------------------------===//
// SCCP Pass
//===----------------------------------------------------------------------===//
117 struct SCCP
: public impl::SCCPBase
<SCCP
> {
118 void runOnOperation() override
;
122 void SCCP::runOnOperation() {
123 Operation
*op
= getOperation();
125 DataFlowSolver solver
;
126 solver
.load
<DeadCodeAnalysis
>();
127 solver
.load
<SparseConstantPropagation
>();
128 if (failed(solver
.initializeAndRun(op
)))
129 return signalPassFailure();
130 rewrite(solver
, op
->getContext(), op
->getRegions());
133 std::unique_ptr
<Pass
> mlir::createSCCPPass() {
134 return std::make_unique
<SCCP
>();