//===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass performs a simple common sub-expression elimination
// algorithm on operations within a region.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/CSE.h"

#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/RecyclingAllocator.h"
#include <cstdint>
#include <deque>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

namespace mlir {
#define GEN_PASS_DEF_CSE
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
36 struct SimpleOperationInfo
: public llvm::DenseMapInfo
<Operation
*> {
37 static unsigned getHashValue(const Operation
*opC
) {
38 return OperationEquivalence::computeHash(
39 const_cast<Operation
*>(opC
),
40 /*hashOperands=*/OperationEquivalence::directHashValue
,
41 /*hashResults=*/OperationEquivalence::ignoreHashValue
,
42 OperationEquivalence::IgnoreLocations
);
44 static bool isEqual(const Operation
*lhsC
, const Operation
*rhsC
) {
45 auto *lhs
= const_cast<Operation
*>(lhsC
);
46 auto *rhs
= const_cast<Operation
*>(rhsC
);
49 if (lhs
== getTombstoneKey() || lhs
== getEmptyKey() ||
50 rhs
== getTombstoneKey() || rhs
== getEmptyKey())
52 return OperationEquivalence::isEquivalentTo(
53 const_cast<Operation
*>(lhsC
), const_cast<Operation
*>(rhsC
),
54 OperationEquivalence::IgnoreLocations
);
60 /// Simple common sub-expression elimination.
63 CSEDriver(RewriterBase
&rewriter
, DominanceInfo
*domInfo
)
64 : rewriter(rewriter
), domInfo(domInfo
) {}
66 /// Simplify all operations within the given op.
67 void simplify(Operation
*op
, bool *changed
= nullptr);
69 int64_t getNumCSE() const { return numCSE
; }
70 int64_t getNumDCE() const { return numDCE
; }
73 /// Shared implementation of operation elimination and scoped map definitions.
74 using AllocatorTy
= llvm::RecyclingAllocator
<
75 llvm::BumpPtrAllocator
,
76 llvm::ScopedHashTableVal
<Operation
*, Operation
*>>;
77 using ScopedMapTy
= llvm::ScopedHashTable
<Operation
*, Operation
*,
78 SimpleOperationInfo
, AllocatorTy
>;
80 /// Cache holding MemoryEffects information between two operations. The first
81 /// operation is stored has the key. The second operation is stored inside a
82 /// pair in the value. The pair also hold the MemoryEffects between those
83 /// two operations. If the MemoryEffects is nullptr then we assume there is
84 /// no operation with MemoryEffects::Write between the two operations.
85 using MemEffectsCache
=
86 DenseMap
<Operation
*, std::pair
<Operation
*, MemoryEffects::Effect
*>>;
88 /// Represents a single entry in the depth first traversal of a CFG.
90 CFGStackNode(ScopedMapTy
&knownValues
, DominanceInfoNode
*node
)
91 : scope(knownValues
), node(node
), childIterator(node
->begin()) {}
93 /// Scope for the known values.
94 ScopedMapTy::ScopeTy scope
;
96 DominanceInfoNode
*node
;
97 DominanceInfoNode::const_iterator childIterator
;
99 /// If this node has been fully processed yet or not.
100 bool processed
= false;
103 /// Attempt to eliminate a redundant operation. Returns success if the
104 /// operation was marked for removal, failure otherwise.
105 LogicalResult
simplifyOperation(ScopedMapTy
&knownValues
, Operation
*op
,
106 bool hasSSADominance
);
107 void simplifyBlock(ScopedMapTy
&knownValues
, Block
*bb
, bool hasSSADominance
);
108 void simplifyRegion(ScopedMapTy
&knownValues
, Region
®ion
);
110 void replaceUsesAndDelete(ScopedMapTy
&knownValues
, Operation
*op
,
111 Operation
*existing
, bool hasSSADominance
);
113 /// Check if there is side-effecting operations other than the given effect
114 /// between the two operations.
115 bool hasOtherSideEffectingOpInBetween(Operation
*fromOp
, Operation
*toOp
);
117 /// A rewriter for modifying the IR.
118 RewriterBase
&rewriter
;
120 /// Operations marked as dead and to be erased.
121 std::vector
<Operation
*> opsToErase
;
122 DominanceInfo
*domInfo
= nullptr;
123 MemEffectsCache memEffectsCache
;
125 // Various statistics.
131 void CSEDriver::replaceUsesAndDelete(ScopedMapTy
&knownValues
, Operation
*op
,
133 bool hasSSADominance
) {
134 // If we find one then replace all uses of the current operation with the
135 // existing one and mark it for deletion. We can only replace an operand in
136 // an operation if it has not been visited yet.
137 if (hasSSADominance
) {
138 // If the region has SSA dominance, then we are guaranteed to have not
139 // visited any use of the current operation.
140 if (auto *rewriteListener
=
141 dyn_cast_if_present
<RewriterBase::Listener
>(rewriter
.getListener()))
142 rewriteListener
->notifyOperationReplaced(op
, existing
);
143 // Replace all uses, but do not remote the operation yet. This does not
144 // notify the listener because the original op is not erased.
145 rewriter
.replaceAllUsesWith(op
->getResults(), existing
->getResults());
146 opsToErase
.push_back(op
);
148 // When the region does not have SSA dominance, we need to check if we
149 // have visited a use before replacing any use.
150 auto wasVisited
= [&](OpOperand
&operand
) {
151 return !knownValues
.count(operand
.getOwner());
153 if (auto *rewriteListener
=
154 dyn_cast_if_present
<RewriterBase::Listener
>(rewriter
.getListener()))
155 for (Value v
: op
->getResults())
156 if (all_of(v
.getUses(), wasVisited
))
157 rewriteListener
->notifyOperationReplaced(op
, existing
);
159 // Replace all uses, but do not remote the operation yet. This does not
160 // notify the listener because the original op is not erased.
161 rewriter
.replaceUsesWithIf(op
->getResults(), existing
->getResults(),
164 // There may be some remaining uses of the operation.
166 opsToErase
.push_back(op
);
169 // If the existing operation has an unknown location and the current
170 // operation doesn't, then set the existing op's location to that of the
172 if (isa
<UnknownLoc
>(existing
->getLoc()) && !isa
<UnknownLoc
>(op
->getLoc()))
173 existing
->setLoc(op
->getLoc());
178 bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation
*fromOp
,
180 assert(fromOp
->getBlock() == toOp
->getBlock());
182 isa
<MemoryEffectOpInterface
>(fromOp
) &&
183 cast
<MemoryEffectOpInterface
>(fromOp
).hasEffect
<MemoryEffects::Read
>() &&
184 isa
<MemoryEffectOpInterface
>(toOp
) &&
185 cast
<MemoryEffectOpInterface
>(toOp
).hasEffect
<MemoryEffects::Read
>());
186 Operation
*nextOp
= fromOp
->getNextNode();
188 memEffectsCache
.try_emplace(fromOp
, std::make_pair(fromOp
, nullptr));
190 auto memEffectsCachePair
= result
.first
->second
;
191 if (memEffectsCachePair
.second
== nullptr) {
192 // No MemoryEffects::Write has been detected until the cached operation.
193 // Continue looking from the cached operation to toOp.
194 nextOp
= memEffectsCachePair
.first
;
196 // MemoryEffects::Write has been detected before so there is no need to
201 while (nextOp
&& nextOp
!= toOp
) {
202 std::optional
<SmallVector
<MemoryEffects::EffectInstance
>> effects
=
203 getEffectsRecursively(nextOp
);
205 // TODO: Do we need to handle other effects generically?
206 // If the operation does not implement the MemoryEffectOpInterface we
207 // conservatively assume it writes.
208 result
.first
->second
=
209 std::make_pair(nextOp
, MemoryEffects::Write::get());
213 for (const MemoryEffects::EffectInstance
&effect
: *effects
) {
214 if (isa
<MemoryEffects::Write
>(effect
.getEffect())) {
215 result
.first
->second
= {nextOp
, MemoryEffects::Write::get()};
219 nextOp
= nextOp
->getNextNode();
221 result
.first
->second
= std::make_pair(toOp
, nullptr);
225 /// Attempt to eliminate a redundant operation.
226 LogicalResult
CSEDriver::simplifyOperation(ScopedMapTy
&knownValues
,
228 bool hasSSADominance
) {
229 // Don't simplify terminator operations.
230 if (op
->hasTrait
<OpTrait::IsTerminator
>())
233 // If the operation is already trivially dead just add it to the erase list.
234 if (isOpTriviallyDead(op
)) {
235 opsToErase
.push_back(op
);
240 // Don't simplify operations with regions that have multiple blocks.
241 // TODO: We need additional tests to verify that we handle such IR correctly.
242 if (!llvm::all_of(op
->getRegions(), [](Region
&r
) {
243 return r
.getBlocks().empty() || llvm::hasSingleElement(r
.getBlocks());
247 // Some simple use case of operation with memory side-effect are dealt with
248 // here. Operations with no side-effect are done after.
249 if (!isMemoryEffectFree(op
)) {
250 auto memEffects
= dyn_cast
<MemoryEffectOpInterface
>(op
);
251 // TODO: Only basic use case for operations with MemoryEffects::Read can be
252 // eleminated now. More work needs to be done for more complicated patterns
253 // and other side-effects.
254 if (!memEffects
|| !memEffects
.onlyHasEffect
<MemoryEffects::Read
>())
257 // Look for an existing definition for the operation.
258 if (auto *existing
= knownValues
.lookup(op
)) {
259 if (existing
->getBlock() == op
->getBlock() &&
260 !hasOtherSideEffectingOpInBetween(existing
, op
)) {
261 // The operation that can be deleted has been reach with no
262 // side-effecting operations in between the existing operation and
263 // this one so we can remove the duplicate.
264 replaceUsesAndDelete(knownValues
, op
, existing
, hasSSADominance
);
268 knownValues
.insert(op
, op
);
272 // Look for an existing definition for the operation.
273 if (auto *existing
= knownValues
.lookup(op
)) {
274 replaceUsesAndDelete(knownValues
, op
, existing
, hasSSADominance
);
279 // Otherwise, we add this operation to the known values map.
280 knownValues
.insert(op
, op
);
284 void CSEDriver::simplifyBlock(ScopedMapTy
&knownValues
, Block
*bb
,
285 bool hasSSADominance
) {
286 for (auto &op
: *bb
) {
287 // Most operations don't have regions, so fast path that case.
288 if (op
.getNumRegions() != 0) {
289 // If this operation is isolated above, we can't process nested regions
290 // with the given 'knownValues' map. This would cause the insertion of
291 // implicit captures in explicit capture only regions.
292 if (op
.mightHaveTrait
<OpTrait::IsIsolatedFromAbove
>()) {
293 ScopedMapTy nestedKnownValues
;
294 for (auto ®ion
: op
.getRegions())
295 simplifyRegion(nestedKnownValues
, region
);
297 // Otherwise, process nested regions normally.
298 for (auto ®ion
: op
.getRegions())
299 simplifyRegion(knownValues
, region
);
303 // If the operation is simplified, we don't process any held regions.
304 if (succeeded(simplifyOperation(knownValues
, &op
, hasSSADominance
)))
307 // Clear the MemoryEffects cache since its usage is by block only.
308 memEffectsCache
.clear();
311 void CSEDriver::simplifyRegion(ScopedMapTy
&knownValues
, Region
®ion
) {
312 // If the region is empty there is nothing to do.
316 bool hasSSADominance
= domInfo
->hasSSADominance(®ion
);
318 // If the region only contains one block, then simplify it directly.
319 if (region
.hasOneBlock()) {
320 ScopedMapTy::ScopeTy
scope(knownValues
);
321 simplifyBlock(knownValues
, ®ion
.front(), hasSSADominance
);
325 // If the region does not have dominanceInfo, then skip it.
326 // TODO: Regions without SSA dominance should define a different
327 // traversal order which is appropriate and can be used here.
328 if (!hasSSADominance
)
331 // Note, deque is being used here because there was significant performance
332 // gains over vector when the container becomes very large due to the
333 // specific access patterns. If/when these performance issues are no
334 // longer a problem we can change this to vector. For more information see
335 // the llvm mailing list discussion on this:
336 // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
337 std::deque
<std::unique_ptr
<CFGStackNode
>> stack
;
339 // Process the nodes of the dom tree for this region.
340 stack
.emplace_back(std::make_unique
<CFGStackNode
>(
341 knownValues
, domInfo
->getRootNode(®ion
)));
343 while (!stack
.empty()) {
344 auto ¤tNode
= stack
.back();
346 // Check to see if we need to process this node.
347 if (!currentNode
->processed
) {
348 currentNode
->processed
= true;
349 simplifyBlock(knownValues
, currentNode
->node
->getBlock(),
353 // Otherwise, check to see if we need to process a child node.
354 if (currentNode
->childIterator
!= currentNode
->node
->end()) {
355 auto *childNode
= *(currentNode
->childIterator
++);
357 std::make_unique
<CFGStackNode
>(knownValues
, childNode
));
359 // Finally, if the node and all of its children have been processed
360 // then we delete the node.
366 void CSEDriver::simplify(Operation
*op
, bool *changed
) {
367 /// Simplify all regions.
368 ScopedMapTy knownValues
;
369 for (auto ®ion
: op
->getRegions())
370 simplifyRegion(knownValues
, region
);
372 /// Erase any operations that were marked as dead during simplification.
373 for (auto *op
: opsToErase
)
374 rewriter
.eraseOp(op
);
376 *changed
= !opsToErase
.empty();
378 // Note: CSE does currently not remove ops with regions, so DominanceInfo
379 // does not have to be invalidated.
382 void mlir::eliminateCommonSubExpressions(RewriterBase
&rewriter
,
383 DominanceInfo
&domInfo
, Operation
*op
,
385 CSEDriver
driver(rewriter
, &domInfo
);
386 driver
.simplify(op
, changed
);
391 struct CSE
: public impl::CSEBase
<CSE
> {
392 void runOnOperation() override
;
396 void CSE::runOnOperation() {
398 IRRewriter
rewriter(&getContext());
399 CSEDriver
driver(rewriter
, &getAnalysis
<DominanceInfo
>());
400 bool changed
= false;
401 driver
.simplify(getOperation(), &changed
);
404 numCSE
= driver
.getNumCSE();
405 numDCE
= driver
.getNumDCE();
407 // If there was no change to the IR, we mark all analyses as preserved.
409 return markAllAnalysesPreserved();
411 // We currently don't remove region operations, so mark dominance as
413 markAnalysesPreserved
<DominanceInfo
, PostDominanceInfo
>();
416 std::unique_ptr
<Pass
> mlir::createCSEPass() { return std::make_unique
<CSE
>(); }