[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / lib / IR / PatternMatch.cpp
blob286f47ce691368eda7dad85052aaea68311b831a
1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/Config/mlir-config.h"
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/Iterators.h"
13 #include "mlir/IR/RegionKindInterface.h"
14 #include "llvm/ADT/SmallPtrSet.h"
16 using namespace mlir;
18 //===----------------------------------------------------------------------===//
19 // PatternBenefit
20 //===----------------------------------------------------------------------===//
22 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
23 assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
24 "This pattern match benefit is too large to represent");
27 unsigned short PatternBenefit::getBenefit() const {
28 assert(!isImpossibleToMatch() && "Pattern doesn't match");
29 return representation;
32 //===----------------------------------------------------------------------===//
33 // Pattern
34 //===----------------------------------------------------------------------===//
36 //===----------------------------------------------------------------------===//
37 // OperationName Root Constructors
39 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
40 MLIRContext *context, ArrayRef<StringRef> generatedNames)
41 : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
42 RootKind::OperationName, generatedNames, benefit, context) {}
44 //===----------------------------------------------------------------------===//
45 // MatchAnyOpTypeTag Root Constructors
47 Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit,
48 MLIRContext *context, ArrayRef<StringRef> generatedNames)
49 : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
51 //===----------------------------------------------------------------------===//
52 // MatchInterfaceOpTypeTag Root Constructors
54 Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
55 PatternBenefit benefit, MLIRContext *context,
56 ArrayRef<StringRef> generatedNames)
57 : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
58 generatedNames, benefit, context) {}
60 //===----------------------------------------------------------------------===//
61 // MatchTraitOpTypeTag Root Constructors
63 Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID,
64 PatternBenefit benefit, MLIRContext *context,
65 ArrayRef<StringRef> generatedNames)
66 : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
67 benefit, context) {}
69 //===----------------------------------------------------------------------===//
70 // General Constructors
72 Pattern::Pattern(const void *rootValue, RootKind rootKind,
73 ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
74 MLIRContext *context)
75 : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
76 contextAndHasBoundedRecursion(context, false) {
77 if (generatedNames.empty())
78 return;
79 generatedOps.reserve(generatedNames.size());
80 std::transform(generatedNames.begin(), generatedNames.end(),
81 std::back_inserter(generatedOps), [context](StringRef name) {
82 return OperationName(name, context);
83 });
86 //===----------------------------------------------------------------------===//
87 // RewritePattern
88 //===----------------------------------------------------------------------===//
90 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
91 llvm_unreachable("need to implement either matchAndRewrite or one of the "
92 "rewrite functions!");
95 LogicalResult RewritePattern::match(Operation *op) const {
96 llvm_unreachable("need to implement either match or matchAndRewrite!");
99 /// Out-of-line vtable anchor.
100 void RewritePattern::anchor() {}
102 //===----------------------------------------------------------------------===//
103 // RewriterBase
104 //===----------------------------------------------------------------------===//
106 bool RewriterBase::Listener::classof(const OpBuilder::Listener *base) {
107 return base->getKind() == OpBuilder::ListenerBase::Kind::RewriterBaseListener;
110 RewriterBase::~RewriterBase() {
111 // Out of line to provide a vtable anchor for the class.
114 void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
115 // Notify the listener that we're about to replace this op.
116 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
117 rewriteListener->notifyOperationReplaced(from, to);
119 replaceAllUsesWith(from->getResults(), to);
122 void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
123 // Notify the listener that we're about to replace this op.
124 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
125 rewriteListener->notifyOperationReplaced(from, to);
127 replaceAllUsesWith(from->getResults(), to->getResults());
130 /// This method replaces the results of the operation with the specified list of
131 /// values. The number of provided values must match the number of results of
132 /// the operation. The replaced op is erased.
133 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
134 assert(op->getNumResults() == newValues.size() &&
135 "incorrect # of replacement values");
137 // Replace all result uses. Also notifies the listener of modifications.
138 replaceAllOpUsesWith(op, newValues);
140 // Erase op and notify listener.
141 eraseOp(op);
144 /// This method replaces the results of the operation with the specified new op
145 /// (replacement). The number of results of the two operations must match. The
146 /// replaced op is erased.
147 void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
148 assert(op && newOp && "expected non-null op");
149 assert(op->getNumResults() == newOp->getNumResults() &&
150 "ops have different number of results");
152 // Replace all result uses. Also notifies the listener of modifications.
153 replaceAllOpUsesWith(op, newOp->getResults());
155 // Erase op and notify listener.
156 eraseOp(op);
159 /// This method erases an operation that is known to have no uses. The uses of
160 /// the given operation *must* be known to be dead.
161 void RewriterBase::eraseOp(Operation *op) {
162 assert(op->use_empty() && "expected 'op' to have no uses");
163 auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
165 // Fast path: If no listener is attached, the op can be dropped in one go.
166 if (!rewriteListener) {
167 op->erase();
168 return;
171 // Helper function that erases a single op.
172 auto eraseSingleOp = [&](Operation *op) {
173 #ifndef NDEBUG
174 // All nested ops should have been erased already.
175 assert(
176 llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
177 "expected empty regions");
178 // All users should have been erased already if the op is in a region with
179 // SSA dominance.
180 if (!op->use_empty() && op->getParentOp())
181 assert(mayBeGraphRegion(*op->getParentRegion()) &&
182 "expected that op has no uses");
183 #endif // NDEBUG
184 rewriteListener->notifyOperationErased(op);
186 // Explicitly drop all uses in case the op is in a graph region.
187 op->dropAllUses();
188 op->erase();
191 // Nested ops must be erased one-by-one, so that listeners have a consistent
192 // view of the IR every time a notification is triggered. Users must be
193 // erased before definitions. I.e., post-order, reverse dominance.
194 std::function<void(Operation *)> eraseTree = [&](Operation *op) {
195 // Erase nested ops.
196 for (Region &r : llvm::reverse(op->getRegions())) {
197 // Erase all blocks in the right order. Successors should be erased
198 // before predecessors because successor blocks may use values defined
199 // in predecessor blocks. A post-order traversal of blocks within a
200 // region visits successors before predecessors. Repeat the traversal
201 // until the region is empty. (The block graph could be disconnected.)
202 while (!r.empty()) {
203 SmallVector<Block *> erasedBlocks;
204 // Some blocks may have invalid successor, use a set including nullptr
205 // to avoid null pointer.
206 llvm::SmallPtrSet<Block *, 4> visited{nullptr};
207 for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
208 // Visit ops in reverse order.
209 for (Operation &op :
210 llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
211 eraseTree(&op);
212 // Do not erase the block immediately. This is not supprted by the
213 // post_order iterator.
214 erasedBlocks.push_back(b);
216 for (Block *b : erasedBlocks) {
217 // Explicitly drop all uses in case there is a cycle in the block
218 // graph.
219 for (BlockArgument bbArg : b->getArguments())
220 bbArg.dropAllUses();
221 b->dropAllUses();
222 eraseBlock(b);
226 // Then erase the enclosing op.
227 eraseSingleOp(op);
230 eraseTree(op);
233 void RewriterBase::eraseBlock(Block *block) {
234 assert(block->use_empty() && "expected 'block' to have no uses");
236 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
237 assert(op.use_empty() && "expected 'op' to have no uses");
238 eraseOp(&op);
241 // Notify the listener that the block is about to be removed.
242 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
243 rewriteListener->notifyBlockErased(block);
245 block->erase();
248 void RewriterBase::finalizeOpModification(Operation *op) {
249 // Notify the listener that the operation was modified.
250 if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
251 rewriteListener->notifyOperationModified(op);
254 void RewriterBase::replaceAllUsesExcept(
255 Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
256 return replaceUsesWithIf(from, to, [&](OpOperand &use) {
257 Operation *user = use.getOwner();
258 return !preservedUsers.contains(user);
262 void RewriterBase::replaceUsesWithIf(Value from, Value to,
263 function_ref<bool(OpOperand &)> functor,
264 bool *allUsesReplaced) {
265 bool allReplaced = true;
266 for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
267 bool replace = functor(operand);
268 if (replace)
269 modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
270 allReplaced &= replace;
272 if (allUsesReplaced)
273 *allUsesReplaced = allReplaced;
276 void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
277 function_ref<bool(OpOperand &)> functor,
278 bool *allUsesReplaced) {
279 assert(from.size() == to.size() && "incorrect number of replacements");
280 bool allReplaced = true;
281 for (auto it : llvm::zip_equal(from, to)) {
282 bool r;
283 replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
284 /*allUsesReplaced=*/&r);
285 allReplaced &= r;
287 if (allUsesReplaced)
288 *allUsesReplaced = allReplaced;
291 void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
292 Block::iterator before,
293 ValueRange argValues) {
294 assert(argValues.size() == source->getNumArguments() &&
295 "incorrect # of argument replacement values");
297 // The source block will be deleted, so it should not have any users (i.e.,
298 // there should be no predecessors).
299 assert(source->hasNoPredecessors() &&
300 "expected 'source' to have no predecessors");
302 if (dest->end() != before) {
303 // The source block will be inserted in the middle of the dest block, so
304 // the source block should have no successors. Otherwise, the remainder of
305 // the dest block would be unreachable.
306 assert(source->hasNoSuccessors() &&
307 "expected 'source' to have no successors");
308 } else {
309 // The source block will be inserted at the end of the dest block, so the
310 // dest block should have no successors. Otherwise, the inserted operations
311 // will be unreachable.
312 assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
315 // Replace all of the successor arguments with the provided values.
316 for (auto it : llvm::zip(source->getArguments(), argValues))
317 replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
319 // Move operations from the source block to the dest block and erase the
320 // source block.
321 if (!listener) {
322 // Fast path: If no listener is attached, move all operations at once.
323 dest->getOperations().splice(before, source->getOperations());
324 } else {
325 while (!source->empty())
326 moveOpBefore(&source->front(), dest, before);
329 // Erase the source block.
330 assert(source->empty() && "expected 'source' to be empty");
331 eraseBlock(source);
334 void RewriterBase::inlineBlockBefore(Block *source, Operation *op,
335 ValueRange argValues) {
336 inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
339 void RewriterBase::mergeBlocks(Block *source, Block *dest,
340 ValueRange argValues) {
341 inlineBlockBefore(source, dest, dest->end(), argValues);
344 /// Split the operations starting at "before" (inclusive) out of the given
345 /// block into a new block, and return it.
346 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
347 // Fast path: If no listener is attached, split the block directly.
348 if (!listener)
349 return block->splitBlock(before);
351 // `createBlock` sets the insertion point at the beginning of the new block.
352 InsertionGuard g(*this);
353 Block *newBlock =
354 createBlock(block->getParent(), std::next(block->getIterator()));
356 // If `before` points to end of the block, no ops should be moved.
357 if (before == block->end())
358 return newBlock;
360 // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
361 // Stop when the operation pointed to by `before` has been moved.
362 while (before->getBlock() != newBlock)
363 moveOpBefore(&block->back(), newBlock, newBlock->begin());
365 return newBlock;
368 /// Move the blocks that belong to "region" before the given position in
369 /// another region. The two regions must be different. The caller is in
370 /// charge to update create the operation transferring the control flow to the
371 /// region and pass it the correct block arguments.
372 void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
373 Region::iterator before) {
374 // Fast path: If no listener is attached, move all blocks at once.
375 if (!listener) {
376 parent.getBlocks().splice(before, region.getBlocks());
377 return;
380 // Move blocks from the beginning of the region one-by-one.
381 while (!region.empty())
382 moveBlockBefore(&region.front(), &parent, before);
384 void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
385 inlineRegionBefore(region, *before->getParent(), before->getIterator());
388 void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
389 moveBlockBefore(block, anotherBlock->getParent(),
390 anotherBlock->getIterator());
393 void RewriterBase::moveBlockBefore(Block *block, Region *region,
394 Region::iterator iterator) {
395 Region *currentRegion = block->getParent();
396 Region::iterator nextIterator = std::next(block->getIterator());
397 block->moveBefore(region, iterator);
398 if (listener)
399 listener->notifyBlockInserted(block, /*previous=*/currentRegion,
400 /*previousIt=*/nextIterator);
403 void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
404 moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
407 void RewriterBase::moveOpBefore(Operation *op, Block *block,
408 Block::iterator iterator) {
409 Block *currentBlock = op->getBlock();
410 Block::iterator nextIterator = std::next(op->getIterator());
411 op->moveBefore(block, iterator);
412 if (listener)
413 listener->notifyOperationInserted(
414 op, /*previous=*/InsertPoint(currentBlock, nextIterator));
417 void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
418 moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
421 void RewriterBase::moveOpAfter(Operation *op, Block *block,
422 Block::iterator iterator) {
423 assert(iterator != block->end() && "cannot move after end of block");
424 moveOpBefore(op, block, std::next(iterator));