[gn build] Port fef54d0393fd
[llvm-project.git] / mlir / lib / Transforms / Utils / FoldUtils.cpp
blobc43f439525526bb51dcd2ab13c0c7ac244b7481e
1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various operation fold utilities. These utilities are
10 // intended to be used by passes to unify and simplify their logic.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Transforms/FoldUtils.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Operation.h"
20 using namespace mlir;
22 /// Given an operation, find the parent region that folded constants should be
23 /// inserted into.
24 static Region *
25 getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces,
26 Block *insertionBlock) {
27 while (Region *region = insertionBlock->getParent()) {
28 // Insert in this region for any of the following scenarios:
29 // * The parent is unregistered, or is known to be isolated from above.
30 // * The parent is a top-level operation.
31 auto *parentOp = region->getParentOp();
32 if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
33 !parentOp->getBlock())
34 return region;
36 // Otherwise, check if this region is a desired insertion region.
37 auto *interface = interfaces.getInterfaceFor(parentOp);
38 if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
39 return region;
41 // Traverse up the parent looking for an insertion region.
42 insertionBlock = parentOp->getBlock();
44 llvm_unreachable("expected valid insertion region");
47 /// A utility function used to materialize a constant for a given attribute and
48 /// type. On success, a valid constant value is returned. Otherwise, null is
49 /// returned
50 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
51 Attribute value, Type type,
52 Location loc) {
53 auto insertPt = builder.getInsertionPoint();
54 (void)insertPt;
56 // Ask the dialect to materialize a constant operation for this value.
57 if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
58 assert(insertPt == builder.getInsertionPoint());
59 assert(matchPattern(constOp, m_Constant()));
60 return constOp;
63 return nullptr;
66 //===----------------------------------------------------------------------===//
67 // OperationFolder
68 //===----------------------------------------------------------------------===//
70 LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
71 if (inPlaceUpdate)
72 *inPlaceUpdate = false;
74 // If this is a unique'd constant, return failure as we know that it has
75 // already been folded.
76 if (isFolderOwnedConstant(op)) {
77 // Check to see if we should rehoist, i.e. if a non-constant operation was
78 // inserted before this one.
79 Block *opBlock = op->getBlock();
80 if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
81 op->moveBefore(&opBlock->front());
82 op->setLoc(erasedFoldedLocation);
84 return failure();
87 // Try to fold the operation.
88 SmallVector<Value, 8> results;
89 if (failed(tryToFold(op, results)))
90 return failure();
92 // Check to see if the operation was just updated in place.
93 if (results.empty()) {
94 if (inPlaceUpdate)
95 *inPlaceUpdate = true;
96 if (auto *rewriteListener = dyn_cast_if_present<RewriterBase::Listener>(
97 rewriter.getListener())) {
98 // Folding API does not notify listeners, so we have to notify manually.
99 rewriteListener->notifyOperationModified(op);
101 return success();
104 // Constant folding succeeded. Replace all of the result values and erase the
105 // operation.
106 notifyRemoval(op);
107 rewriter.replaceOp(op, results);
108 return success();
111 bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
112 Block *opBlock = op->getBlock();
114 // If this is a constant we unique'd, we don't need to insert, but we can
115 // check to see if we should rehoist it.
116 if (isFolderOwnedConstant(op)) {
117 if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) {
118 op->moveBefore(&opBlock->front());
119 op->setLoc(erasedFoldedLocation);
121 return true;
124 // Get the constant value of the op if necessary.
125 if (!constValue) {
126 matchPattern(op, m_Constant(&constValue));
127 assert(constValue && "expected `op` to be a constant");
128 } else {
129 // Ensure that the provided constant was actually correct.
130 #ifndef NDEBUG
131 Attribute expectedValue;
132 matchPattern(op, m_Constant(&expectedValue));
133 assert(
134 expectedValue == constValue &&
135 "provided constant value was not the expected value of the constant");
136 #endif
139 // Check for an existing constant operation for the attribute value.
140 Region *insertRegion = getInsertionRegion(interfaces, opBlock);
141 auto &uniquedConstants = foldScopes[insertRegion];
142 Operation *&folderConstOp = uniquedConstants[std::make_tuple(
143 op->getDialect(), constValue, *op->result_type_begin())];
145 // If there is an existing constant, replace `op`.
146 if (folderConstOp) {
147 notifyRemoval(op);
148 rewriter.replaceOp(op, folderConstOp->getResults());
149 folderConstOp->setLoc(erasedFoldedLocation);
150 return false;
153 // Otherwise, we insert `op`. If `op` is in the insertion block and is either
154 // already at the front of the block, or the previous operation is already a
155 // constant we unique'd (i.e. one we inserted), then we don't need to do
156 // anything. Otherwise, we move the constant to the insertion block.
157 Block *insertBlock = &insertRegion->front();
158 if (opBlock != insertBlock || (&insertBlock->front() != op &&
159 !isFolderOwnedConstant(op->getPrevNode()))) {
160 op->moveBefore(&insertBlock->front());
161 op->setLoc(erasedFoldedLocation);
164 folderConstOp = op;
165 referencedDialects[op].push_back(op->getDialect());
166 return true;
169 /// Notifies that the given constant `op` should be remove from this
170 /// OperationFolder's internal bookkeeping.
171 void OperationFolder::notifyRemoval(Operation *op) {
172 // Check to see if this operation is uniqued within the folder.
173 auto it = referencedDialects.find(op);
174 if (it == referencedDialects.end())
175 return;
177 // Get the constant value for this operation, this is the value that was used
178 // to unique the operation internally.
179 Attribute constValue;
180 matchPattern(op, m_Constant(&constValue));
181 assert(constValue);
183 // Get the constant map that this operation was uniqued in.
184 auto &uniquedConstants =
185 foldScopes[getInsertionRegion(interfaces, op->getBlock())];
187 // Erase all of the references to this operation.
188 auto type = op->getResult(0).getType();
189 for (auto *dialect : it->second)
190 uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
191 referencedDialects.erase(it);
194 /// Clear out any constants cached inside of the folder.
195 void OperationFolder::clear() {
196 foldScopes.clear();
197 referencedDialects.clear();
200 /// Get or create a constant using the given builder. On success this returns
201 /// the constant operation, nullptr otherwise.
202 Value OperationFolder::getOrCreateConstant(Block *block, Dialect *dialect,
203 Attribute value, Type type) {
204 // Find an insertion point for the constant.
205 auto *insertRegion = getInsertionRegion(interfaces, block);
206 auto &entry = insertRegion->front();
207 rewriter.setInsertionPointToStart(&entry);
209 // Get the constant map for the insertion region of this operation.
210 // Use erased location since the op is being built at the front of block.
211 auto &uniquedConstants = foldScopes[insertRegion];
212 Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, value,
213 type, erasedFoldedLocation);
214 return constOp ? constOp->getResult(0) : Value();
217 bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
218 return referencedDialects.count(op);
221 /// Tries to perform folding on the given `op`. If successful, populates
222 /// `results` with the results of the folding.
223 LogicalResult OperationFolder::tryToFold(Operation *op,
224 SmallVectorImpl<Value> &results) {
225 SmallVector<OpFoldResult, 8> foldResults;
226 if (failed(op->fold(foldResults)) ||
227 failed(processFoldResults(op, results, foldResults)))
228 return failure();
229 return success();
232 LogicalResult
233 OperationFolder::processFoldResults(Operation *op,
234 SmallVectorImpl<Value> &results,
235 ArrayRef<OpFoldResult> foldResults) {
236 // Check to see if the operation was just updated in place.
237 if (foldResults.empty())
238 return success();
239 assert(foldResults.size() == op->getNumResults());
241 // Create a builder to insert new operations into the entry block of the
242 // insertion region.
243 auto *insertRegion = getInsertionRegion(interfaces, op->getBlock());
244 auto &entry = insertRegion->front();
245 rewriter.setInsertionPointToStart(&entry);
247 // Get the constant map for the insertion region of this operation.
248 auto &uniquedConstants = foldScopes[insertRegion];
250 // Create the result constants and replace the results.
251 auto *dialect = op->getDialect();
252 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
253 assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
255 // Check if the result was an SSA value.
256 if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
257 results.emplace_back(repl);
258 continue;
261 // Check to see if there is a canonicalized version of this constant.
262 auto res = op->getResult(i);
263 Attribute attrRepl = foldResults[i].get<Attribute>();
264 if (auto *constOp =
265 tryGetOrCreateConstant(uniquedConstants, dialect, attrRepl,
266 res.getType(), erasedFoldedLocation)) {
267 // Ensure that this constant dominates the operation we are replacing it
268 // with. This may not automatically happen if the operation being folded
269 // was inserted before the constant within the insertion block.
270 Block *opBlock = op->getBlock();
271 if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
272 constOp->moveBefore(&opBlock->front());
274 results.push_back(constOp->getResult(0));
275 continue;
277 // If materialization fails, cleanup any operations generated for the
278 // previous results and return failure.
279 for (Operation &op : llvm::make_early_inc_range(
280 llvm::make_range(entry.begin(), rewriter.getInsertionPoint()))) {
281 notifyRemoval(&op);
282 rewriter.eraseOp(&op);
285 results.clear();
286 return failure();
289 return success();
292 /// Try to get or create a new constant entry. On success this returns the
293 /// constant operation value, nullptr otherwise.
294 Operation *
295 OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
296 Dialect *dialect, Attribute value,
297 Type type, Location loc) {
298 // Check if an existing mapping already exists.
299 auto constKey = std::make_tuple(dialect, value, type);
300 Operation *&constOp = uniquedConstants[constKey];
301 if (constOp) {
302 if (loc != constOp->getLoc())
303 constOp->setLoc(erasedFoldedLocation);
304 return constOp;
307 // If one doesn't exist, try to materialize one.
308 if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
309 return nullptr;
311 // Check to see if the generated constant is in the expected dialect.
312 auto *newDialect = constOp->getDialect();
313 if (newDialect == dialect) {
314 referencedDialects[constOp].push_back(dialect);
315 return constOp;
318 // If it isn't, then we also need to make sure that the mapping for the new
319 // dialect is valid.
320 auto newKey = std::make_tuple(newDialect, value, type);
322 // If an existing operation in the new dialect already exists, delete the
323 // materialized operation in favor of the existing one.
324 if (auto *existingOp = uniquedConstants.lookup(newKey)) {
325 notifyRemoval(constOp);
326 rewriter.eraseOp(constOp);
327 referencedDialects[existingOp].push_back(dialect);
328 if (loc != existingOp->getLoc())
329 existingOp->setLoc(erasedFoldedLocation);
330 return constOp = existingOp;
333 // Otherwise, update the new dialect to the materialized operation.
334 referencedDialects[constOp].assign({dialect, newDialect});
335 auto newIt = uniquedConstants.insert({newKey, constOp});
336 return newIt.first->second;