//===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include <cassert>
21 #define DEBUG_TYPE "constant-propagation"
24 using namespace mlir::dataflow
;
//===----------------------------------------------------------------------===//
// ConstantValue
//===----------------------------------------------------------------------===//
30 void ConstantValue::print(raw_ostream
&os
) const {
31 if (isUninitialized()) {
32 os
<< "<UNINITIALIZED>";
35 if (getConstantValue() == nullptr) {
39 return getConstantValue().print(os
);
//===----------------------------------------------------------------------===//
// SparseConstantPropagation
//===----------------------------------------------------------------------===//
46 LogicalResult
SparseConstantPropagation::visitOperation(
47 Operation
*op
, ArrayRef
<const Lattice
<ConstantValue
> *> operands
,
48 ArrayRef
<Lattice
<ConstantValue
> *> results
) {
49 LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op
<< "\n");
51 // Don't try to simulate the results of a region operation as we can't
52 // guarantee that folding will be out-of-place. We don't allow in-place
53 // folds as the desire here is for simulated execution, and not general
55 if (op
->getNumRegions()) {
56 setAllToEntryStates(results
);
60 SmallVector
<Attribute
, 8> constantOperands
;
61 constantOperands
.reserve(op
->getNumOperands());
62 for (auto *operandLattice
: operands
) {
63 if (operandLattice
->getValue().isUninitialized())
65 constantOperands
.push_back(operandLattice
->getValue().getConstantValue());
68 // Save the original operands and attributes just in case the operation
69 // folds in-place. The constant passed in may not correspond to the real
70 // runtime value, so in-place updates are not allowed.
71 SmallVector
<Value
, 8> originalOperands(op
->getOperands());
72 DictionaryAttr originalAttrs
= op
->getAttrDictionary();
74 // Simulate the result of folding this operation to a constant. If folding
75 // fails or was an in-place fold, mark the results as overdefined.
76 SmallVector
<OpFoldResult
, 8> foldResults
;
77 foldResults
.reserve(op
->getNumResults());
78 if (failed(op
->fold(constantOperands
, foldResults
))) {
79 setAllToEntryStates(results
);
83 // If the folding was in-place, mark the results as overdefined and reset
84 // the operation. We don't allow in-place folds as the desire here is for
85 // simulated execution, and not general folding.
86 if (foldResults
.empty()) {
87 op
->setOperands(originalOperands
);
88 op
->setAttrs(originalAttrs
);
89 setAllToEntryStates(results
);
93 // Merge the fold results into the lattice for this operation.
94 assert(foldResults
.size() == op
->getNumResults() && "invalid result size");
95 for (const auto it
: llvm::zip(results
, foldResults
)) {
96 Lattice
<ConstantValue
> *lattice
= std::get
<0>(it
);
98 // Merge in the result of the fold, either a constant or a value.
99 OpFoldResult foldResult
= std::get
<1>(it
);
100 if (Attribute attr
= llvm::dyn_cast_if_present
<Attribute
>(foldResult
)) {
101 LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr
<< "\n");
102 propagateIfChanged(lattice
,
103 lattice
->join(ConstantValue(attr
, op
->getDialect())));
105 LLVM_DEBUG(llvm::dbgs()
106 << "Folded to value: " << foldResult
.get
<Value
>() << "\n");
107 AbstractSparseForwardDataFlowAnalysis::join(
108 lattice
, *getLatticeElement(foldResult
.get
<Value
>()));
114 void SparseConstantPropagation::setToEntryState(
115 Lattice
<ConstantValue
> *lattice
) {
116 propagateIfChanged(lattice
,
117 lattice
->join(ConstantValue::getUnknownConstant()));