IR: de-duplicate two CmpInst routines (NFC) (#116866)
[llvm-project.git] / mlir / lib / Transforms / CSE.cpp
blob3affd88d158de593f8a68a100d5c05d875f4e1ee
1 //===- CSE.cpp - Common Sub-expression Elimination ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a simple common sub-expression elimination
10 // algorithm on operations within a region.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Transforms/CSE.h"
16 #include "mlir/IR/Dominance.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Interfaces/SideEffectInterfaces.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/Passes.h"
21 #include "llvm/ADT/DenseMapInfo.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/ScopedHashTable.h"
24 #include "llvm/Support/Allocator.h"
25 #include "llvm/Support/RecyclingAllocator.h"
26 #include <deque>
28 namespace mlir {
29 #define GEN_PASS_DEF_CSE
30 #include "mlir/Transforms/Passes.h.inc"
31 } // namespace mlir
33 using namespace mlir;
35 namespace {
36 struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
37 static unsigned getHashValue(const Operation *opC) {
38 return OperationEquivalence::computeHash(
39 const_cast<Operation *>(opC),
40 /*hashOperands=*/OperationEquivalence::directHashValue,
41 /*hashResults=*/OperationEquivalence::ignoreHashValue,
42 OperationEquivalence::IgnoreLocations);
44 static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
45 auto *lhs = const_cast<Operation *>(lhsC);
46 auto *rhs = const_cast<Operation *>(rhsC);
47 if (lhs == rhs)
48 return true;
49 if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
50 rhs == getTombstoneKey() || rhs == getEmptyKey())
51 return false;
52 return OperationEquivalence::isEquivalentTo(
53 const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
54 OperationEquivalence::IgnoreLocations);
57 } // namespace
59 namespace {
60 /// Simple common sub-expression elimination.
61 class CSEDriver {
62 public:
63 CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
64 : rewriter(rewriter), domInfo(domInfo) {}
66 /// Simplify all operations within the given op.
67 void simplify(Operation *op, bool *changed = nullptr);
69 int64_t getNumCSE() const { return numCSE; }
70 int64_t getNumDCE() const { return numDCE; }
72 private:
73 /// Shared implementation of operation elimination and scoped map definitions.
74 using AllocatorTy = llvm::RecyclingAllocator<
75 llvm::BumpPtrAllocator,
76 llvm::ScopedHashTableVal<Operation *, Operation *>>;
77 using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
78 SimpleOperationInfo, AllocatorTy>;
80 /// Cache holding MemoryEffects information between two operations. The first
81 /// operation is stored has the key. The second operation is stored inside a
82 /// pair in the value. The pair also hold the MemoryEffects between those
83 /// two operations. If the MemoryEffects is nullptr then we assume there is
84 /// no operation with MemoryEffects::Write between the two operations.
85 using MemEffectsCache =
86 DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;
88 /// Represents a single entry in the depth first traversal of a CFG.
89 struct CFGStackNode {
90 CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
91 : scope(knownValues), node(node), childIterator(node->begin()) {}
93 /// Scope for the known values.
94 ScopedMapTy::ScopeTy scope;
96 DominanceInfoNode *node;
97 DominanceInfoNode::const_iterator childIterator;
99 /// If this node has been fully processed yet or not.
100 bool processed = false;
103 /// Attempt to eliminate a redundant operation. Returns success if the
104 /// operation was marked for removal, failure otherwise.
105 LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
106 bool hasSSADominance);
107 void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
108 void simplifyRegion(ScopedMapTy &knownValues, Region &region);
110 void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
111 Operation *existing, bool hasSSADominance);
113 /// Check if there is side-effecting operations other than the given effect
114 /// between the two operations.
115 bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
117 /// A rewriter for modifying the IR.
118 RewriterBase &rewriter;
120 /// Operations marked as dead and to be erased.
121 std::vector<Operation *> opsToErase;
122 DominanceInfo *domInfo = nullptr;
123 MemEffectsCache memEffectsCache;
125 // Various statistics.
126 int64_t numCSE = 0;
127 int64_t numDCE = 0;
129 } // namespace
131 void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
132 Operation *existing,
133 bool hasSSADominance) {
134 // If we find one then replace all uses of the current operation with the
135 // existing one and mark it for deletion. We can only replace an operand in
136 // an operation if it has not been visited yet.
137 if (hasSSADominance) {
138 // If the region has SSA dominance, then we are guaranteed to have not
139 // visited any use of the current operation.
140 if (auto *rewriteListener =
141 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
142 rewriteListener->notifyOperationReplaced(op, existing);
143 // Replace all uses, but do not remote the operation yet. This does not
144 // notify the listener because the original op is not erased.
145 rewriter.replaceAllUsesWith(op->getResults(), existing->getResults());
146 opsToErase.push_back(op);
147 } else {
148 // When the region does not have SSA dominance, we need to check if we
149 // have visited a use before replacing any use.
150 auto wasVisited = [&](OpOperand &operand) {
151 return !knownValues.count(operand.getOwner());
153 if (auto *rewriteListener =
154 dyn_cast_if_present<RewriterBase::Listener>(rewriter.getListener()))
155 for (Value v : op->getResults())
156 if (all_of(v.getUses(), wasVisited))
157 rewriteListener->notifyOperationReplaced(op, existing);
159 // Replace all uses, but do not remote the operation yet. This does not
160 // notify the listener because the original op is not erased.
161 rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(),
162 wasVisited);
164 // There may be some remaining uses of the operation.
165 if (op->use_empty())
166 opsToErase.push_back(op);
169 // If the existing operation has an unknown location and the current
170 // operation doesn't, then set the existing op's location to that of the
171 // current op.
172 if (isa<UnknownLoc>(existing->getLoc()) && !isa<UnknownLoc>(op->getLoc()))
173 existing->setLoc(op->getLoc());
175 ++numCSE;
178 bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
179 Operation *toOp) {
180 assert(fromOp->getBlock() == toOp->getBlock());
181 assert(
182 isa<MemoryEffectOpInterface>(fromOp) &&
183 cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() &&
184 isa<MemoryEffectOpInterface>(toOp) &&
185 cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>());
186 Operation *nextOp = fromOp->getNextNode();
187 auto result =
188 memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
189 if (result.second) {
190 auto memEffectsCachePair = result.first->second;
191 if (memEffectsCachePair.second == nullptr) {
192 // No MemoryEffects::Write has been detected until the cached operation.
193 // Continue looking from the cached operation to toOp.
194 nextOp = memEffectsCachePair.first;
195 } else {
196 // MemoryEffects::Write has been detected before so there is no need to
197 // check further.
198 return true;
201 while (nextOp && nextOp != toOp) {
202 std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
203 getEffectsRecursively(nextOp);
204 if (!effects) {
205 // TODO: Do we need to handle other effects generically?
206 // If the operation does not implement the MemoryEffectOpInterface we
207 // conservatively assume it writes.
208 result.first->second =
209 std::make_pair(nextOp, MemoryEffects::Write::get());
210 return true;
213 for (const MemoryEffects::EffectInstance &effect : *effects) {
214 if (isa<MemoryEffects::Write>(effect.getEffect())) {
215 result.first->second = {nextOp, MemoryEffects::Write::get()};
216 return true;
219 nextOp = nextOp->getNextNode();
221 result.first->second = std::make_pair(toOp, nullptr);
222 return false;
225 /// Attempt to eliminate a redundant operation.
226 LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
227 Operation *op,
228 bool hasSSADominance) {
229 // Don't simplify terminator operations.
230 if (op->hasTrait<OpTrait::IsTerminator>())
231 return failure();
233 // If the operation is already trivially dead just add it to the erase list.
234 if (isOpTriviallyDead(op)) {
235 opsToErase.push_back(op);
236 ++numDCE;
237 return success();
240 // Don't simplify operations with regions that have multiple blocks.
241 // TODO: We need additional tests to verify that we handle such IR correctly.
242 if (!llvm::all_of(op->getRegions(), [](Region &r) {
243 return r.getBlocks().empty() || llvm::hasSingleElement(r.getBlocks());
245 return failure();
247 // Some simple use case of operation with memory side-effect are dealt with
248 // here. Operations with no side-effect are done after.
249 if (!isMemoryEffectFree(op)) {
250 auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
251 // TODO: Only basic use case for operations with MemoryEffects::Read can be
252 // eleminated now. More work needs to be done for more complicated patterns
253 // and other side-effects.
254 if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>())
255 return failure();
257 // Look for an existing definition for the operation.
258 if (auto *existing = knownValues.lookup(op)) {
259 if (existing->getBlock() == op->getBlock() &&
260 !hasOtherSideEffectingOpInBetween(existing, op)) {
261 // The operation that can be deleted has been reach with no
262 // side-effecting operations in between the existing operation and
263 // this one so we can remove the duplicate.
264 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
265 return success();
268 knownValues.insert(op, op);
269 return failure();
272 // Look for an existing definition for the operation.
273 if (auto *existing = knownValues.lookup(op)) {
274 replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
275 ++numCSE;
276 return success();
279 // Otherwise, we add this operation to the known values map.
280 knownValues.insert(op, op);
281 return failure();
284 void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
285 bool hasSSADominance) {
286 for (auto &op : *bb) {
287 // Most operations don't have regions, so fast path that case.
288 if (op.getNumRegions() != 0) {
289 // If this operation is isolated above, we can't process nested regions
290 // with the given 'knownValues' map. This would cause the insertion of
291 // implicit captures in explicit capture only regions.
292 if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
293 ScopedMapTy nestedKnownValues;
294 for (auto &region : op.getRegions())
295 simplifyRegion(nestedKnownValues, region);
296 } else {
297 // Otherwise, process nested regions normally.
298 for (auto &region : op.getRegions())
299 simplifyRegion(knownValues, region);
303 // If the operation is simplified, we don't process any held regions.
304 if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
305 continue;
307 // Clear the MemoryEffects cache since its usage is by block only.
308 memEffectsCache.clear();
311 void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) {
312 // If the region is empty there is nothing to do.
313 if (region.empty())
314 return;
316 bool hasSSADominance = domInfo->hasSSADominance(&region);
318 // If the region only contains one block, then simplify it directly.
319 if (region.hasOneBlock()) {
320 ScopedMapTy::ScopeTy scope(knownValues);
321 simplifyBlock(knownValues, &region.front(), hasSSADominance);
322 return;
325 // If the region does not have dominanceInfo, then skip it.
326 // TODO: Regions without SSA dominance should define a different
327 // traversal order which is appropriate and can be used here.
328 if (!hasSSADominance)
329 return;
331 // Note, deque is being used here because there was significant performance
332 // gains over vector when the container becomes very large due to the
333 // specific access patterns. If/when these performance issues are no
334 // longer a problem we can change this to vector. For more information see
335 // the llvm mailing list discussion on this:
336 // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html
337 std::deque<std::unique_ptr<CFGStackNode>> stack;
339 // Process the nodes of the dom tree for this region.
340 stack.emplace_back(std::make_unique<CFGStackNode>(
341 knownValues, domInfo->getRootNode(&region)));
343 while (!stack.empty()) {
344 auto &currentNode = stack.back();
346 // Check to see if we need to process this node.
347 if (!currentNode->processed) {
348 currentNode->processed = true;
349 simplifyBlock(knownValues, currentNode->node->getBlock(),
350 hasSSADominance);
353 // Otherwise, check to see if we need to process a child node.
354 if (currentNode->childIterator != currentNode->node->end()) {
355 auto *childNode = *(currentNode->childIterator++);
356 stack.emplace_back(
357 std::make_unique<CFGStackNode>(knownValues, childNode));
358 } else {
359 // Finally, if the node and all of its children have been processed
360 // then we delete the node.
361 stack.pop_back();
366 void CSEDriver::simplify(Operation *op, bool *changed) {
367 /// Simplify all regions.
368 ScopedMapTy knownValues;
369 for (auto &region : op->getRegions())
370 simplifyRegion(knownValues, region);
372 /// Erase any operations that were marked as dead during simplification.
373 for (auto *op : opsToErase)
374 rewriter.eraseOp(op);
375 if (changed)
376 *changed = !opsToErase.empty();
378 // Note: CSE does currently not remove ops with regions, so DominanceInfo
379 // does not have to be invalidated.
382 void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter,
383 DominanceInfo &domInfo, Operation *op,
384 bool *changed) {
385 CSEDriver driver(rewriter, &domInfo);
386 driver.simplify(op, changed);
389 namespace {
390 /// CSE pass.
391 struct CSE : public impl::CSEBase<CSE> {
392 void runOnOperation() override;
394 } // namespace
396 void CSE::runOnOperation() {
397 // Simplify the IR.
398 IRRewriter rewriter(&getContext());
399 CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>());
400 bool changed = false;
401 driver.simplify(getOperation(), &changed);
403 // Set statistics.
404 numCSE = driver.getNumCSE();
405 numDCE = driver.getNumDCE();
407 // If there was no change to the IR, we mark all analyses as preserved.
408 if (!changed)
409 return markAllAnalysesPreserved();
411 // We currently don't remove region operations, so mark dominance as
412 // preserved.
413 markAnalysesPreserved<DominanceInfo, PostDominanceInfo>();
416 std::unique_ptr<Pass> mlir::createCSEPass() { return std::make_unique<CSE>(); }