Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Reducer / ReductionTreePass.cpp
blobb00045a3a41b7a1de3f20f5fa842882849bf80fc
1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
12 // implements methods that define the different possible traversals of the
13 // reduction tree.
15 //===----------------------------------------------------------------------===//
17 #include "mlir/IR/DialectInterface.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/Reducer/Passes.h"
20 #include "mlir/Reducer/ReductionNode.h"
21 #include "mlir/Reducer/ReductionPatternInterface.h"
22 #include "mlir/Reducer/Tester.h"
23 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/Allocator.h"
29 #include "llvm/Support/ManagedStatic.h"
31 namespace mlir {
32 #define GEN_PASS_DEF_REDUCTIONTREE
33 #include "mlir/Reducer/Passes.h.inc"
34 } // namespace mlir
36 using namespace mlir;
38 /// We implicitly number each operation in the region and if an operation's
39 /// number falls into rangeToKeep, we need to keep it and apply the given
40 /// rewrite patterns on it.
41 static void applyPatterns(Region &region,
42 const FrozenRewritePatternSet &patterns,
43 ArrayRef<ReductionNode::Range> rangeToKeep,
44 bool eraseOpNotInRange) {
45 std::vector<Operation *> opsNotInRange;
46 std::vector<Operation *> opsInRange;
47 size_t keepIndex = 0;
48 for (const auto &op : enumerate(region.getOps())) {
49 int index = op.index();
50 if (keepIndex < rangeToKeep.size() &&
51 index == rangeToKeep[keepIndex].second)
52 ++keepIndex;
53 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
54 opsNotInRange.push_back(&op.value());
55 else
56 opsInRange.push_back(&op.value());
59 // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
60 // matching in above iteration. Besides, erase op not-in-range may end up in
61 // invalid module, so `applyOpPatternsAndFold` should come before that
62 // transform.
63 for (Operation *op : opsInRange) {
64 // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
65 // because we don't have expectation this reduction will be success or not.
66 GreedyRewriteConfig config;
67 config.strictMode = GreedyRewriteStrictness::ExistingOps;
68 (void)applyOpPatternsAndFold(op, patterns, config);
71 if (eraseOpNotInRange)
72 for (Operation *op : opsNotInRange) {
73 op->dropAllUses();
74 op->erase();
78 /// We will apply the reducer patterns to the operations in the ranges specified
79 /// by ReductionNode. Note that we are not able to remove an operation without
80 /// replacing it with another valid operation. However, The validity of module
81 /// reduction is based on the Tester provided by the user and that means certain
82 /// invalid module is still interested by the use. Thus we provide an
83 /// alternative way to remove operations, which is using `eraseOpNotInRange` to
84 /// erase the operations not in the range specified by ReductionNode.
85 template <typename IteratorType>
86 static LogicalResult findOptimal(ModuleOp module, Region &region,
87 const FrozenRewritePatternSet &patterns,
88 const Tester &test, bool eraseOpNotInRange) {
89 std::pair<Tester::Interestingness, size_t> initStatus =
90 test.isInteresting(module);
91 // While exploring the reduction tree, we always branch from an interesting
92 // node. Thus the root node must be interesting.
93 if (initStatus.first != Tester::Interestingness::True)
94 return module.emitWarning() << "uninterested module will not be reduced";
96 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
98 std::vector<ReductionNode::Range> ranges{
99 {0, std::distance(region.op_begin(), region.op_end())}};
101 ReductionNode *root = allocator.Allocate();
102 new (root) ReductionNode(nullptr, ranges, allocator);
103 // Duplicate the module for root node and locate the region in the copy.
104 if (failed(root->initialize(module, region)))
105 llvm_unreachable("unexpected initialization failure");
106 root->update(initStatus);
108 ReductionNode *smallestNode = root;
109 IteratorType iter(root);
111 while (iter != IteratorType::end()) {
112 ReductionNode &currentNode = *iter;
113 Region &curRegion = currentNode.getRegion();
115 applyPatterns(curRegion, patterns, currentNode.getRanges(),
116 eraseOpNotInRange);
117 currentNode.update(test.isInteresting(currentNode.getModule()));
119 if (currentNode.isInteresting() == Tester::Interestingness::True &&
120 currentNode.getSize() < smallestNode->getSize())
121 smallestNode = &currentNode;
123 ++iter;
126 // At here, we have found an optimal path to reduce the given region. Retrieve
127 // the path and apply the reducer to it.
128 SmallVector<ReductionNode *> trace;
129 ReductionNode *curNode = smallestNode;
130 trace.push_back(curNode);
131 while (curNode != root) {
132 curNode = curNode->getParent();
133 trace.push_back(curNode);
136 // Reduce the region through the optimal path.
137 while (!trace.empty()) {
138 ReductionNode *top = trace.pop_back_val();
139 applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
142 if (test.isInteresting(module).first != Tester::Interestingness::True)
143 llvm::report_fatal_error("Reduced module is not interesting");
144 if (test.isInteresting(module).second != smallestNode->getSize())
145 llvm::report_fatal_error(
146 "Reduced module doesn't have consistent size with smallestNode");
147 return success();
150 template <typename IteratorType>
151 static LogicalResult findOptimal(ModuleOp module, Region &region,
152 const FrozenRewritePatternSet &patterns,
153 const Tester &test) {
154 // We separate the reduction process into 2 steps, the first one is to erase
155 // redundant operations and the second one is to apply the reducer patterns.
157 // In the first phase, we don't apply any patterns so that we only select the
158 // range of operations to keep to the module stay interesting.
159 if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
160 /*eraseOpNotInRange=*/true)))
161 return failure();
162 // In the second phase, we suppose that no operation is redundant, so we try
163 // to rewrite the operation into simpler form.
164 return findOptimal<IteratorType>(module, region, patterns, test,
165 /*eraseOpNotInRange=*/false);
168 namespace {
170 //===----------------------------------------------------------------------===//
171 // Reduction Pattern Interface Collection
172 //===----------------------------------------------------------------------===//
174 class ReductionPatternInterfaceCollection
175 : public DialectInterfaceCollection<DialectReductionPatternInterface> {
176 public:
177 using Base::Base;
179 // Collect the reduce patterns defined by each dialect.
180 void populateReductionPatterns(RewritePatternSet &pattern) const {
181 for (const DialectReductionPatternInterface &interface : *this)
182 interface.populateReductionPatterns(pattern);
186 //===----------------------------------------------------------------------===//
187 // ReductionTreePass
188 //===----------------------------------------------------------------------===//
190 /// This class defines the Reduction Tree Pass. It provides a framework to
191 /// to implement a reduction pass using a tree structure to keep track of the
192 /// generated reduced variants.
193 class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> {
194 public:
195 ReductionTreePass() = default;
196 ReductionTreePass(const ReductionTreePass &pass) = default;
198 LogicalResult initialize(MLIRContext *context) override;
200 /// Runs the pass instance in the pass pipeline.
201 void runOnOperation() override;
203 private:
204 LogicalResult reduceOp(ModuleOp module, Region &region);
206 FrozenRewritePatternSet reducerPatterns;
209 } // namespace
211 LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
212 RewritePatternSet patterns(context);
213 ReductionPatternInterfaceCollection reducePatternCollection(context);
214 reducePatternCollection.populateReductionPatterns(patterns);
215 reducerPatterns = std::move(patterns);
216 return success();
219 void ReductionTreePass::runOnOperation() {
220 Operation *topOperation = getOperation();
221 while (topOperation->getParentOp() != nullptr)
222 topOperation = topOperation->getParentOp();
223 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
224 if (!module) {
225 emitError(getOperation()->getLoc())
226 << "top-level op must be 'builtin.module'";
227 return signalPassFailure();
230 SmallVector<Operation *, 8> workList;
231 workList.push_back(getOperation());
233 do {
234 Operation *op = workList.pop_back_val();
236 for (Region &region : op->getRegions())
237 if (!region.empty())
238 if (failed(reduceOp(module, region)))
239 return signalPassFailure();
241 for (Region &region : op->getRegions())
242 for (Operation &op : region.getOps())
243 if (op.getNumRegions() != 0)
244 workList.push_back(&op);
245 } while (!workList.empty());
248 LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
249 Tester test(testerName, testerArgs);
250 switch (traversalModeId) {
251 case TraversalMode::SinglePath:
252 return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
253 module, region, reducerPatterns, test);
254 default:
255 return module.emitError() << "unsupported traversal mode detected";
259 std::unique_ptr<Pass> mlir::createReductionTreePass() {
260 return std::make_unique<ReductionTreePass>();