1 //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines the Reduction Tree Pass class. It provides a framework for
10 // the implementation of different reduction passes in the MLIR Reduce tool. It
11 // allows for custom specification of the variant generation behavior. It
// implements methods that define the different possible traversals of the
// reduction tree.
15 //===----------------------------------------------------------------------===//
17 #include "mlir/IR/DialectInterface.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/Reducer/Passes.h"
20 #include "mlir/Reducer/ReductionNode.h"
21 #include "mlir/Reducer/ReductionPatternInterface.h"
22 #include "mlir/Reducer/Tester.h"
23 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/Allocator.h"
29 #include "llvm/Support/ManagedStatic.h"
32 #define GEN_PASS_DEF_REDUCTIONTREE
33 #include "mlir/Reducer/Passes.h.inc"
38 /// We implicitly number each operation in the region and if an operation's
39 /// number falls into rangeToKeep, we need to keep it and apply the given
40 /// rewrite patterns on it.
41 static void applyPatterns(Region
®ion
,
42 const FrozenRewritePatternSet
&patterns
,
43 ArrayRef
<ReductionNode::Range
> rangeToKeep
,
44 bool eraseOpNotInRange
) {
45 std::vector
<Operation
*> opsNotInRange
;
46 std::vector
<Operation
*> opsInRange
;
48 for (const auto &op
: enumerate(region
.getOps())) {
49 int index
= op
.index();
50 if (keepIndex
< rangeToKeep
.size() &&
51 index
== rangeToKeep
[keepIndex
].second
)
53 if (keepIndex
== rangeToKeep
.size() || index
< rangeToKeep
[keepIndex
].first
)
54 opsNotInRange
.push_back(&op
.value());
56 opsInRange
.push_back(&op
.value());
59 // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
60 // matching in above iteration. Besides, erase op not-in-range may end up in
61 // invalid module, so `applyOpPatternsAndFold` should come before that
63 for (Operation
*op
: opsInRange
) {
64 // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
65 // because we don't have expectation this reduction will be success or not.
66 GreedyRewriteConfig config
;
67 config
.strictMode
= GreedyRewriteStrictness::ExistingOps
;
68 (void)applyOpPatternsAndFold(op
, patterns
, config
);
71 if (eraseOpNotInRange
)
72 for (Operation
*op
: opsNotInRange
) {
78 /// We will apply the reducer patterns to the operations in the ranges specified
79 /// by ReductionNode. Note that we are not able to remove an operation without
80 /// replacing it with another valid operation. However, The validity of module
81 /// reduction is based on the Tester provided by the user and that means certain
82 /// invalid module is still interested by the use. Thus we provide an
83 /// alternative way to remove operations, which is using `eraseOpNotInRange` to
84 /// erase the operations not in the range specified by ReductionNode.
85 template <typename IteratorType
>
86 static LogicalResult
findOptimal(ModuleOp module
, Region
®ion
,
87 const FrozenRewritePatternSet
&patterns
,
88 const Tester
&test
, bool eraseOpNotInRange
) {
89 std::pair
<Tester::Interestingness
, size_t> initStatus
=
90 test
.isInteresting(module
);
91 // While exploring the reduction tree, we always branch from an interesting
92 // node. Thus the root node must be interesting.
93 if (initStatus
.first
!= Tester::Interestingness::True
)
94 return module
.emitWarning() << "uninterested module will not be reduced";
96 llvm::SpecificBumpPtrAllocator
<ReductionNode
> allocator
;
98 std::vector
<ReductionNode::Range
> ranges
{
99 {0, std::distance(region
.op_begin(), region
.op_end())}};
101 ReductionNode
*root
= allocator
.Allocate();
102 new (root
) ReductionNode(nullptr, ranges
, allocator
);
103 // Duplicate the module for root node and locate the region in the copy.
104 if (failed(root
->initialize(module
, region
)))
105 llvm_unreachable("unexpected initialization failure");
106 root
->update(initStatus
);
108 ReductionNode
*smallestNode
= root
;
109 IteratorType
iter(root
);
111 while (iter
!= IteratorType::end()) {
112 ReductionNode
¤tNode
= *iter
;
113 Region
&curRegion
= currentNode
.getRegion();
115 applyPatterns(curRegion
, patterns
, currentNode
.getRanges(),
117 currentNode
.update(test
.isInteresting(currentNode
.getModule()));
119 if (currentNode
.isInteresting() == Tester::Interestingness::True
&&
120 currentNode
.getSize() < smallestNode
->getSize())
121 smallestNode
= ¤tNode
;
126 // At here, we have found an optimal path to reduce the given region. Retrieve
127 // the path and apply the reducer to it.
128 SmallVector
<ReductionNode
*> trace
;
129 ReductionNode
*curNode
= smallestNode
;
130 trace
.push_back(curNode
);
131 while (curNode
!= root
) {
132 curNode
= curNode
->getParent();
133 trace
.push_back(curNode
);
136 // Reduce the region through the optimal path.
137 while (!trace
.empty()) {
138 ReductionNode
*top
= trace
.pop_back_val();
139 applyPatterns(region
, patterns
, top
->getStartRanges(), eraseOpNotInRange
);
142 if (test
.isInteresting(module
).first
!= Tester::Interestingness::True
)
143 llvm::report_fatal_error("Reduced module is not interesting");
144 if (test
.isInteresting(module
).second
!= smallestNode
->getSize())
145 llvm::report_fatal_error(
146 "Reduced module doesn't have consistent size with smallestNode");
150 template <typename IteratorType
>
151 static LogicalResult
findOptimal(ModuleOp module
, Region
®ion
,
152 const FrozenRewritePatternSet
&patterns
,
153 const Tester
&test
) {
154 // We separate the reduction process into 2 steps, the first one is to erase
155 // redundant operations and the second one is to apply the reducer patterns.
157 // In the first phase, we don't apply any patterns so that we only select the
158 // range of operations to keep to the module stay interesting.
159 if (failed(findOptimal
<IteratorType
>(module
, region
, /*patterns=*/{}, test
,
160 /*eraseOpNotInRange=*/true)))
162 // In the second phase, we suppose that no operation is redundant, so we try
163 // to rewrite the operation into simpler form.
164 return findOptimal
<IteratorType
>(module
, region
, patterns
, test
,
165 /*eraseOpNotInRange=*/false);
170 //===----------------------------------------------------------------------===//
171 // Reduction Pattern Interface Collection
172 //===----------------------------------------------------------------------===//
174 class ReductionPatternInterfaceCollection
175 : public DialectInterfaceCollection
<DialectReductionPatternInterface
> {
179 // Collect the reduce patterns defined by each dialect.
180 void populateReductionPatterns(RewritePatternSet
&pattern
) const {
181 for (const DialectReductionPatternInterface
&interface
: *this)
182 interface
.populateReductionPatterns(pattern
);
186 //===----------------------------------------------------------------------===//
188 //===----------------------------------------------------------------------===//
190 /// This class defines the Reduction Tree Pass. It provides a framework to
191 /// to implement a reduction pass using a tree structure to keep track of the
192 /// generated reduced variants.
193 class ReductionTreePass
: public impl::ReductionTreeBase
<ReductionTreePass
> {
195 ReductionTreePass() = default;
196 ReductionTreePass(const ReductionTreePass
&pass
) = default;
198 LogicalResult
initialize(MLIRContext
*context
) override
;
200 /// Runs the pass instance in the pass pipeline.
201 void runOnOperation() override
;
204 LogicalResult
reduceOp(ModuleOp module
, Region
®ion
);
206 FrozenRewritePatternSet reducerPatterns
;
211 LogicalResult
ReductionTreePass::initialize(MLIRContext
*context
) {
212 RewritePatternSet
patterns(context
);
213 ReductionPatternInterfaceCollection
reducePatternCollection(context
);
214 reducePatternCollection
.populateReductionPatterns(patterns
);
215 reducerPatterns
= std::move(patterns
);
219 void ReductionTreePass::runOnOperation() {
220 Operation
*topOperation
= getOperation();
221 while (topOperation
->getParentOp() != nullptr)
222 topOperation
= topOperation
->getParentOp();
223 ModuleOp module
= dyn_cast
<ModuleOp
>(topOperation
);
225 emitError(getOperation()->getLoc())
226 << "top-level op must be 'builtin.module'";
227 return signalPassFailure();
230 SmallVector
<Operation
*, 8> workList
;
231 workList
.push_back(getOperation());
234 Operation
*op
= workList
.pop_back_val();
236 for (Region
®ion
: op
->getRegions())
238 if (failed(reduceOp(module
, region
)))
239 return signalPassFailure();
241 for (Region
®ion
: op
->getRegions())
242 for (Operation
&op
: region
.getOps())
243 if (op
.getNumRegions() != 0)
244 workList
.push_back(&op
);
245 } while (!workList
.empty());
248 LogicalResult
ReductionTreePass::reduceOp(ModuleOp module
, Region
®ion
) {
249 Tester
test(testerName
, testerArgs
);
250 switch (traversalModeId
) {
251 case TraversalMode::SinglePath
:
252 return findOptimal
<ReductionNode::iterator
<TraversalMode::SinglePath
>>(
253 module
, region
, reducerPatterns
, test
);
255 return module
.emitError() << "unsupported traversal mode detected";
/// Creates a new ReductionTreePass instance, returned as a generic Pass.
259 std::unique_ptr
<Pass
> mlir::createReductionTreePass() {
260 return std::make_unique
<ReductionTreePass
>();