1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simplify their logic.
12 //===----------------------------------------------------------------------===//
#include "mlir/Transforms/FoldUtils.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
22 /// Given an operation, find the parent region that folded constants should be
25 getInsertionRegion(DialectInterfaceCollection
<DialectFoldInterface
> &interfaces
,
26 Block
*insertionBlock
) {
27 while (Region
*region
= insertionBlock
->getParent()) {
28 // Insert in this region for any of the following scenarios:
29 // * The parent is unregistered, or is known to be isolated from above.
30 // * The parent is a top-level operation.
31 auto *parentOp
= region
->getParentOp();
32 if (parentOp
->mightHaveTrait
<OpTrait::IsIsolatedFromAbove
>() ||
33 !parentOp
->getBlock())
36 // Otherwise, check if this region is a desired insertion region.
37 auto *interface
= interfaces
.getInterfaceFor(parentOp
);
38 if (LLVM_UNLIKELY(interface
&& interface
->shouldMaterializeInto(region
)))
41 // Traverse up the parent looking for an insertion region.
42 insertionBlock
= parentOp
->getBlock();
44 llvm_unreachable("expected valid insertion region");
47 /// A utility function used to materialize a constant for a given attribute and
48 /// type. On success, a valid constant value is returned. Otherwise, null is
50 static Operation
*materializeConstant(Dialect
*dialect
, OpBuilder
&builder
,
51 Attribute value
, Type type
,
53 auto insertPt
= builder
.getInsertionPoint();
56 // Ask the dialect to materialize a constant operation for this value.
57 if (auto *constOp
= dialect
->materializeConstant(builder
, value
, type
, loc
)) {
58 assert(insertPt
== builder
.getInsertionPoint());
59 assert(matchPattern(constOp
, m_Constant()));
//===----------------------------------------------------------------------===//
// OperationFolder
//===----------------------------------------------------------------------===//
70 LogicalResult
OperationFolder::tryToFold(Operation
*op
, bool *inPlaceUpdate
) {
72 *inPlaceUpdate
= false;
74 // If this is a unique'd constant, return failure as we know that it has
75 // already been folded.
76 if (isFolderOwnedConstant(op
)) {
77 // Check to see if we should rehoist, i.e. if a non-constant operation was
78 // inserted before this one.
79 Block
*opBlock
= op
->getBlock();
80 if (&opBlock
->front() != op
&& !isFolderOwnedConstant(op
->getPrevNode())) {
81 op
->moveBefore(&opBlock
->front());
82 op
->setLoc(erasedFoldedLocation
);
87 // Try to fold the operation.
88 SmallVector
<Value
, 8> results
;
89 if (failed(tryToFold(op
, results
)))
92 // Check to see if the operation was just updated in place.
93 if (results
.empty()) {
95 *inPlaceUpdate
= true;
96 if (auto *rewriteListener
= dyn_cast_if_present
<RewriterBase::Listener
>(
97 rewriter
.getListener())) {
98 // Folding API does not notify listeners, so we have to notify manually.
99 rewriteListener
->notifyOperationModified(op
);
104 // Constant folding succeeded. Replace all of the result values and erase the
107 rewriter
.replaceOp(op
, results
);
111 bool OperationFolder::insertKnownConstant(Operation
*op
, Attribute constValue
) {
112 Block
*opBlock
= op
->getBlock();
114 // If this is a constant we unique'd, we don't need to insert, but we can
115 // check to see if we should rehoist it.
116 if (isFolderOwnedConstant(op
)) {
117 if (&opBlock
->front() != op
&& !isFolderOwnedConstant(op
->getPrevNode())) {
118 op
->moveBefore(&opBlock
->front());
119 op
->setLoc(erasedFoldedLocation
);
124 // Get the constant value of the op if necessary.
126 matchPattern(op
, m_Constant(&constValue
));
127 assert(constValue
&& "expected `op` to be a constant");
129 // Ensure that the provided constant was actually correct.
131 Attribute expectedValue
;
132 matchPattern(op
, m_Constant(&expectedValue
));
134 expectedValue
== constValue
&&
135 "provided constant value was not the expected value of the constant");
139 // Check for an existing constant operation for the attribute value.
140 Region
*insertRegion
= getInsertionRegion(interfaces
, opBlock
);
141 auto &uniquedConstants
= foldScopes
[insertRegion
];
142 Operation
*&folderConstOp
= uniquedConstants
[std::make_tuple(
143 op
->getDialect(), constValue
, *op
->result_type_begin())];
145 // If there is an existing constant, replace `op`.
148 rewriter
.replaceOp(op
, folderConstOp
->getResults());
149 folderConstOp
->setLoc(erasedFoldedLocation
);
153 // Otherwise, we insert `op`. If `op` is in the insertion block and is either
154 // already at the front of the block, or the previous operation is already a
155 // constant we unique'd (i.e. one we inserted), then we don't need to do
156 // anything. Otherwise, we move the constant to the insertion block.
157 Block
*insertBlock
= &insertRegion
->front();
158 if (opBlock
!= insertBlock
|| (&insertBlock
->front() != op
&&
159 !isFolderOwnedConstant(op
->getPrevNode()))) {
160 op
->moveBefore(&insertBlock
->front());
161 op
->setLoc(erasedFoldedLocation
);
165 referencedDialects
[op
].push_back(op
->getDialect());
169 /// Notifies that the given constant `op` should be remove from this
170 /// OperationFolder's internal bookkeeping.
171 void OperationFolder::notifyRemoval(Operation
*op
) {
172 // Check to see if this operation is uniqued within the folder.
173 auto it
= referencedDialects
.find(op
);
174 if (it
== referencedDialects
.end())
177 // Get the constant value for this operation, this is the value that was used
178 // to unique the operation internally.
179 Attribute constValue
;
180 matchPattern(op
, m_Constant(&constValue
));
183 // Get the constant map that this operation was uniqued in.
184 auto &uniquedConstants
=
185 foldScopes
[getInsertionRegion(interfaces
, op
->getBlock())];
187 // Erase all of the references to this operation.
188 auto type
= op
->getResult(0).getType();
189 for (auto *dialect
: it
->second
)
190 uniquedConstants
.erase(std::make_tuple(dialect
, constValue
, type
));
191 referencedDialects
.erase(it
);
194 /// Clear out any constants cached inside of the folder.
195 void OperationFolder::clear() {
197 referencedDialects
.clear();
200 /// Get or create a constant using the given builder. On success this returns
201 /// the constant operation, nullptr otherwise.
202 Value
OperationFolder::getOrCreateConstant(Block
*block
, Dialect
*dialect
,
203 Attribute value
, Type type
) {
204 // Find an insertion point for the constant.
205 auto *insertRegion
= getInsertionRegion(interfaces
, block
);
206 auto &entry
= insertRegion
->front();
207 rewriter
.setInsertionPointToStart(&entry
);
209 // Get the constant map for the insertion region of this operation.
210 // Use erased location since the op is being built at the front of block.
211 auto &uniquedConstants
= foldScopes
[insertRegion
];
212 Operation
*constOp
= tryGetOrCreateConstant(uniquedConstants
, dialect
, value
,
213 type
, erasedFoldedLocation
);
214 return constOp
? constOp
->getResult(0) : Value();
217 bool OperationFolder::isFolderOwnedConstant(Operation
*op
) const {
218 return referencedDialects
.count(op
);
221 /// Tries to perform folding on the given `op`. If successful, populates
222 /// `results` with the results of the folding.
223 LogicalResult
OperationFolder::tryToFold(Operation
*op
,
224 SmallVectorImpl
<Value
> &results
) {
225 SmallVector
<OpFoldResult
, 8> foldResults
;
226 if (failed(op
->fold(foldResults
)) ||
227 failed(processFoldResults(op
, results
, foldResults
)))
233 OperationFolder::processFoldResults(Operation
*op
,
234 SmallVectorImpl
<Value
> &results
,
235 ArrayRef
<OpFoldResult
> foldResults
) {
236 // Check to see if the operation was just updated in place.
237 if (foldResults
.empty())
239 assert(foldResults
.size() == op
->getNumResults());
241 // Create a builder to insert new operations into the entry block of the
243 auto *insertRegion
= getInsertionRegion(interfaces
, op
->getBlock());
244 auto &entry
= insertRegion
->front();
245 rewriter
.setInsertionPointToStart(&entry
);
247 // Get the constant map for the insertion region of this operation.
248 auto &uniquedConstants
= foldScopes
[insertRegion
];
250 // Create the result constants and replace the results.
251 auto *dialect
= op
->getDialect();
252 for (unsigned i
= 0, e
= op
->getNumResults(); i
!= e
; ++i
) {
253 assert(!foldResults
[i
].isNull() && "expected valid OpFoldResult");
255 // Check if the result was an SSA value.
256 if (auto repl
= llvm::dyn_cast_if_present
<Value
>(foldResults
[i
])) {
257 results
.emplace_back(repl
);
261 // Check to see if there is a canonicalized version of this constant.
262 auto res
= op
->getResult(i
);
263 Attribute attrRepl
= foldResults
[i
].get
<Attribute
>();
265 tryGetOrCreateConstant(uniquedConstants
, dialect
, attrRepl
,
266 res
.getType(), erasedFoldedLocation
)) {
267 // Ensure that this constant dominates the operation we are replacing it
268 // with. This may not automatically happen if the operation being folded
269 // was inserted before the constant within the insertion block.
270 Block
*opBlock
= op
->getBlock();
271 if (opBlock
== constOp
->getBlock() && &opBlock
->front() != constOp
)
272 constOp
->moveBefore(&opBlock
->front());
274 results
.push_back(constOp
->getResult(0));
277 // If materialization fails, cleanup any operations generated for the
278 // previous results and return failure.
279 for (Operation
&op
: llvm::make_early_inc_range(
280 llvm::make_range(entry
.begin(), rewriter
.getInsertionPoint()))) {
282 rewriter
.eraseOp(&op
);
292 /// Try to get or create a new constant entry. On success this returns the
293 /// constant operation value, nullptr otherwise.
295 OperationFolder::tryGetOrCreateConstant(ConstantMap
&uniquedConstants
,
296 Dialect
*dialect
, Attribute value
,
297 Type type
, Location loc
) {
298 // Check if an existing mapping already exists.
299 auto constKey
= std::make_tuple(dialect
, value
, type
);
300 Operation
*&constOp
= uniquedConstants
[constKey
];
302 if (loc
!= constOp
->getLoc())
303 constOp
->setLoc(erasedFoldedLocation
);
307 // If one doesn't exist, try to materialize one.
308 if (!(constOp
= materializeConstant(dialect
, rewriter
, value
, type
, loc
)))
311 // Check to see if the generated constant is in the expected dialect.
312 auto *newDialect
= constOp
->getDialect();
313 if (newDialect
== dialect
) {
314 referencedDialects
[constOp
].push_back(dialect
);
318 // If it isn't, then we also need to make sure that the mapping for the new
320 auto newKey
= std::make_tuple(newDialect
, value
, type
);
322 // If an existing operation in the new dialect already exists, delete the
323 // materialized operation in favor of the existing one.
324 if (auto *existingOp
= uniquedConstants
.lookup(newKey
)) {
325 notifyRemoval(constOp
);
326 rewriter
.eraseOp(constOp
);
327 referencedDialects
[existingOp
].push_back(dialect
);
328 if (loc
!= existingOp
->getLoc())
329 existingOp
->setLoc(erasedFoldedLocation
);
330 return constOp
= existingOp
;
333 // Otherwise, update the new dialect to the materialized operation.
334 referencedDialects
[constOp
].assign({dialect
, newDialect
});
335 auto newIt
= uniquedConstants
.insert({newKey
, constOp
});
336 return newIt
.first
->second
;