//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/PatternMatch.h"
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"
#include "llvm/ADT/SmallPtrSet.h"

using namespace mlir;
//===----------------------------------------------------------------------===//
// PatternBenefit
//===----------------------------------------------------------------------===//
22 PatternBenefit::PatternBenefit(unsigned benefit
) : representation(benefit
) {
23 assert(representation
== benefit
&& benefit
!= ImpossibleToMatchSentinel
&&
24 "This pattern match benefit is too large to represent");
27 unsigned short PatternBenefit::getBenefit() const {
28 assert(!isImpossibleToMatch() && "Pattern doesn't match");
29 return representation
;
//===----------------------------------------------------------------------===//
// Pattern
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// OperationName Root Constructors
39 Pattern::Pattern(StringRef rootName
, PatternBenefit benefit
,
40 MLIRContext
*context
, ArrayRef
<StringRef
> generatedNames
)
41 : Pattern(OperationName(rootName
, context
).getAsOpaquePointer(),
42 RootKind::OperationName
, generatedNames
, benefit
, context
) {}
//===----------------------------------------------------------------------===//
// MatchAnyOpTypeTag Root Constructors
47 Pattern::Pattern(MatchAnyOpTypeTag tag
, PatternBenefit benefit
,
48 MLIRContext
*context
, ArrayRef
<StringRef
> generatedNames
)
49 : Pattern(nullptr, RootKind::Any
, generatedNames
, benefit
, context
) {}
//===----------------------------------------------------------------------===//
// MatchInterfaceOpTypeTag Root Constructors
54 Pattern::Pattern(MatchInterfaceOpTypeTag tag
, TypeID interfaceID
,
55 PatternBenefit benefit
, MLIRContext
*context
,
56 ArrayRef
<StringRef
> generatedNames
)
57 : Pattern(interfaceID
.getAsOpaquePointer(), RootKind::InterfaceID
,
58 generatedNames
, benefit
, context
) {}
//===----------------------------------------------------------------------===//
// MatchTraitOpTypeTag Root Constructors
63 Pattern::Pattern(MatchTraitOpTypeTag tag
, TypeID traitID
,
64 PatternBenefit benefit
, MLIRContext
*context
,
65 ArrayRef
<StringRef
> generatedNames
)
66 : Pattern(traitID
.getAsOpaquePointer(), RootKind::TraitID
, generatedNames
,
//===----------------------------------------------------------------------===//
// General Constructors
72 Pattern::Pattern(const void *rootValue
, RootKind rootKind
,
73 ArrayRef
<StringRef
> generatedNames
, PatternBenefit benefit
,
75 : rootValue(rootValue
), rootKind(rootKind
), benefit(benefit
),
76 contextAndHasBoundedRecursion(context
, false) {
77 if (generatedNames
.empty())
79 generatedOps
.reserve(generatedNames
.size());
80 std::transform(generatedNames
.begin(), generatedNames
.end(),
81 std::back_inserter(generatedOps
), [context
](StringRef name
) {
82 return OperationName(name
, context
);
//===----------------------------------------------------------------------===//
// RewritePattern
//===----------------------------------------------------------------------===//
90 void RewritePattern::rewrite(Operation
*op
, PatternRewriter
&rewriter
) const {
91 llvm_unreachable("need to implement either matchAndRewrite or one of the "
92 "rewrite functions!");
95 LogicalResult
RewritePattern::match(Operation
*op
) const {
96 llvm_unreachable("need to implement either match or matchAndRewrite!");
99 /// Out-of-line vtable anchor.
100 void RewritePattern::anchor() {}
//===----------------------------------------------------------------------===//
// RewriterBase
//===----------------------------------------------------------------------===//
106 bool RewriterBase::Listener::classof(const OpBuilder::Listener
*base
) {
107 return base
->getKind() == OpBuilder::ListenerBase::Kind::RewriterBaseListener
;
110 RewriterBase::~RewriterBase() {
111 // Out of line to provide a vtable anchor for the class.
114 void RewriterBase::replaceAllOpUsesWith(Operation
*from
, ValueRange to
) {
115 // Notify the listener that we're about to replace this op.
116 if (auto *rewriteListener
= dyn_cast_if_present
<Listener
>(listener
))
117 rewriteListener
->notifyOperationReplaced(from
, to
);
119 replaceAllUsesWith(from
->getResults(), to
);
122 void RewriterBase::replaceAllOpUsesWith(Operation
*from
, Operation
*to
) {
123 // Notify the listener that we're about to replace this op.
124 if (auto *rewriteListener
= dyn_cast_if_present
<Listener
>(listener
))
125 rewriteListener
->notifyOperationReplaced(from
, to
);
127 replaceAllUsesWith(from
->getResults(), to
->getResults());
130 /// This method replaces the results of the operation with the specified list of
131 /// values. The number of provided values must match the number of results of
132 /// the operation. The replaced op is erased.
133 void RewriterBase::replaceOp(Operation
*op
, ValueRange newValues
) {
134 assert(op
->getNumResults() == newValues
.size() &&
135 "incorrect # of replacement values");
137 // Replace all result uses. Also notifies the listener of modifications.
138 replaceAllOpUsesWith(op
, newValues
);
140 // Erase op and notify listener.
144 /// This method replaces the results of the operation with the specified new op
145 /// (replacement). The number of results of the two operations must match. The
146 /// replaced op is erased.
147 void RewriterBase::replaceOp(Operation
*op
, Operation
*newOp
) {
148 assert(op
&& newOp
&& "expected non-null op");
149 assert(op
->getNumResults() == newOp
->getNumResults() &&
150 "ops have different number of results");
152 // Replace all result uses. Also notifies the listener of modifications.
153 replaceAllOpUsesWith(op
, newOp
->getResults());
155 // Erase op and notify listener.
159 /// This method erases an operation that is known to have no uses. The uses of
160 /// the given operation *must* be known to be dead.
161 void RewriterBase::eraseOp(Operation
*op
) {
162 assert(op
->use_empty() && "expected 'op' to have no uses");
163 auto *rewriteListener
= dyn_cast_if_present
<Listener
>(listener
);
165 // Fast path: If no listener is attached, the op can be dropped in one go.
166 if (!rewriteListener
) {
171 // Helper function that erases a single op.
172 auto eraseSingleOp
= [&](Operation
*op
) {
174 // All nested ops should have been erased already.
176 llvm::all_of(op
->getRegions(), [&](Region
&r
) { return r
.empty(); }) &&
177 "expected empty regions");
178 // All users should have been erased already if the op is in a region with
180 if (!op
->use_empty() && op
->getParentOp())
181 assert(mayBeGraphRegion(*op
->getParentRegion()) &&
182 "expected that op has no uses");
184 rewriteListener
->notifyOperationErased(op
);
186 // Explicitly drop all uses in case the op is in a graph region.
191 // Nested ops must be erased one-by-one, so that listeners have a consistent
192 // view of the IR every time a notification is triggered. Users must be
193 // erased before definitions. I.e., post-order, reverse dominance.
194 std::function
<void(Operation
*)> eraseTree
= [&](Operation
*op
) {
196 for (Region
&r
: llvm::reverse(op
->getRegions())) {
197 // Erase all blocks in the right order. Successors should be erased
198 // before predecessors because successor blocks may use values defined
199 // in predecessor blocks. A post-order traversal of blocks within a
200 // region visits successors before predecessors. Repeat the traversal
201 // until the region is empty. (The block graph could be disconnected.)
203 SmallVector
<Block
*> erasedBlocks
;
204 // Some blocks may have invalid successor, use a set including nullptr
205 // to avoid null pointer.
206 llvm::SmallPtrSet
<Block
*, 4> visited
{nullptr};
207 for (Block
*b
: llvm::post_order_ext(&r
.front(), visited
)) {
208 // Visit ops in reverse order.
210 llvm::make_early_inc_range(ReverseIterator::makeIterable(*b
)))
212 // Do not erase the block immediately. This is not supprted by the
213 // post_order iterator.
214 erasedBlocks
.push_back(b
);
216 for (Block
*b
: erasedBlocks
) {
217 // Explicitly drop all uses in case there is a cycle in the block
219 for (BlockArgument bbArg
: b
->getArguments())
226 // Then erase the enclosing op.
233 void RewriterBase::eraseBlock(Block
*block
) {
234 assert(block
->use_empty() && "expected 'block' to have no uses");
236 for (auto &op
: llvm::make_early_inc_range(llvm::reverse(*block
))) {
237 assert(op
.use_empty() && "expected 'op' to have no uses");
241 // Notify the listener that the block is about to be removed.
242 if (auto *rewriteListener
= dyn_cast_if_present
<Listener
>(listener
))
243 rewriteListener
->notifyBlockErased(block
);
248 void RewriterBase::finalizeOpModification(Operation
*op
) {
249 // Notify the listener that the operation was modified.
250 if (auto *rewriteListener
= dyn_cast_if_present
<Listener
>(listener
))
251 rewriteListener
->notifyOperationModified(op
);
254 void RewriterBase::replaceAllUsesExcept(
255 Value from
, Value to
, const SmallPtrSetImpl
<Operation
*> &preservedUsers
) {
256 return replaceUsesWithIf(from
, to
, [&](OpOperand
&use
) {
257 Operation
*user
= use
.getOwner();
258 return !preservedUsers
.contains(user
);
262 void RewriterBase::replaceUsesWithIf(Value from
, Value to
,
263 function_ref
<bool(OpOperand
&)> functor
,
264 bool *allUsesReplaced
) {
265 bool allReplaced
= true;
266 for (OpOperand
&operand
: llvm::make_early_inc_range(from
.getUses())) {
267 bool replace
= functor(operand
);
269 modifyOpInPlace(operand
.getOwner(), [&]() { operand
.set(to
); });
270 allReplaced
&= replace
;
273 *allUsesReplaced
= allReplaced
;
276 void RewriterBase::replaceUsesWithIf(ValueRange from
, ValueRange to
,
277 function_ref
<bool(OpOperand
&)> functor
,
278 bool *allUsesReplaced
) {
279 assert(from
.size() == to
.size() && "incorrect number of replacements");
280 bool allReplaced
= true;
281 for (auto it
: llvm::zip_equal(from
, to
)) {
283 replaceUsesWithIf(std::get
<0>(it
), std::get
<1>(it
), functor
,
284 /*allUsesReplaced=*/&r
);
288 *allUsesReplaced
= allReplaced
;
291 void RewriterBase::inlineBlockBefore(Block
*source
, Block
*dest
,
292 Block::iterator before
,
293 ValueRange argValues
) {
294 assert(argValues
.size() == source
->getNumArguments() &&
295 "incorrect # of argument replacement values");
297 // The source block will be deleted, so it should not have any users (i.e.,
298 // there should be no predecessors).
299 assert(source
->hasNoPredecessors() &&
300 "expected 'source' to have no predecessors");
302 if (dest
->end() != before
) {
303 // The source block will be inserted in the middle of the dest block, so
304 // the source block should have no successors. Otherwise, the remainder of
305 // the dest block would be unreachable.
306 assert(source
->hasNoSuccessors() &&
307 "expected 'source' to have no successors");
309 // The source block will be inserted at the end of the dest block, so the
310 // dest block should have no successors. Otherwise, the inserted operations
311 // will be unreachable.
312 assert(dest
->hasNoSuccessors() && "expected 'dest' to have no successors");
315 // Replace all of the successor arguments with the provided values.
316 for (auto it
: llvm::zip(source
->getArguments(), argValues
))
317 replaceAllUsesWith(std::get
<0>(it
), std::get
<1>(it
));
319 // Move operations from the source block to the dest block and erase the
322 // Fast path: If no listener is attached, move all operations at once.
323 dest
->getOperations().splice(before
, source
->getOperations());
325 while (!source
->empty())
326 moveOpBefore(&source
->front(), dest
, before
);
329 // Erase the source block.
330 assert(source
->empty() && "expected 'source' to be empty");
334 void RewriterBase::inlineBlockBefore(Block
*source
, Operation
*op
,
335 ValueRange argValues
) {
336 inlineBlockBefore(source
, op
->getBlock(), op
->getIterator(), argValues
);
339 void RewriterBase::mergeBlocks(Block
*source
, Block
*dest
,
340 ValueRange argValues
) {
341 inlineBlockBefore(source
, dest
, dest
->end(), argValues
);
344 /// Split the operations starting at "before" (inclusive) out of the given
345 /// block into a new block, and return it.
346 Block
*RewriterBase::splitBlock(Block
*block
, Block::iterator before
) {
347 // Fast path: If no listener is attached, split the block directly.
349 return block
->splitBlock(before
);
351 // `createBlock` sets the insertion point at the beginning of the new block.
352 InsertionGuard
g(*this);
354 createBlock(block
->getParent(), std::next(block
->getIterator()));
356 // If `before` points to end of the block, no ops should be moved.
357 if (before
== block
->end())
360 // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
361 // Stop when the operation pointed to by `before` has been moved.
362 while (before
->getBlock() != newBlock
)
363 moveOpBefore(&block
->back(), newBlock
, newBlock
->begin());
368 /// Move the blocks that belong to "region" before the given position in
369 /// another region. The two regions must be different. The caller is in
370 /// charge to update create the operation transferring the control flow to the
371 /// region and pass it the correct block arguments.
372 void RewriterBase::inlineRegionBefore(Region
®ion
, Region
&parent
,
373 Region::iterator before
) {
374 // Fast path: If no listener is attached, move all blocks at once.
376 parent
.getBlocks().splice(before
, region
.getBlocks());
380 // Move blocks from the beginning of the region one-by-one.
381 while (!region
.empty())
382 moveBlockBefore(®ion
.front(), &parent
, before
);
384 void RewriterBase::inlineRegionBefore(Region
®ion
, Block
*before
) {
385 inlineRegionBefore(region
, *before
->getParent(), before
->getIterator());
388 void RewriterBase::moveBlockBefore(Block
*block
, Block
*anotherBlock
) {
389 moveBlockBefore(block
, anotherBlock
->getParent(),
390 anotherBlock
->getIterator());
393 void RewriterBase::moveBlockBefore(Block
*block
, Region
*region
,
394 Region::iterator iterator
) {
395 Region
*currentRegion
= block
->getParent();
396 Region::iterator nextIterator
= std::next(block
->getIterator());
397 block
->moveBefore(region
, iterator
);
399 listener
->notifyBlockInserted(block
, /*previous=*/currentRegion
,
400 /*previousIt=*/nextIterator
);
403 void RewriterBase::moveOpBefore(Operation
*op
, Operation
*existingOp
) {
404 moveOpBefore(op
, existingOp
->getBlock(), existingOp
->getIterator());
407 void RewriterBase::moveOpBefore(Operation
*op
, Block
*block
,
408 Block::iterator iterator
) {
409 Block
*currentBlock
= op
->getBlock();
410 Block::iterator nextIterator
= std::next(op
->getIterator());
411 op
->moveBefore(block
, iterator
);
413 listener
->notifyOperationInserted(
414 op
, /*previous=*/InsertPoint(currentBlock
, nextIterator
));
417 void RewriterBase::moveOpAfter(Operation
*op
, Operation
*existingOp
) {
418 moveOpAfter(op
, existingOp
->getBlock(), existingOp
->getIterator());
421 void RewriterBase::moveOpAfter(Operation
*op
, Block
*block
,
422 Block::iterator iterator
) {
423 assert(iterator
!= block
->end() && "cannot move after end of block");
424 moveOpBefore(op
, block
, std::next(iterator
));