IR: de-duplicate two CmpInst routines (NFC) (#116866)
[llvm-project.git] / mlir / lib / Transforms / Utils / CommutativityUtils.cpp
blob5ba6e4747cb57f245fd992ba638683283ac514cf
1 //===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a commutativity utility pattern and a function to
10 // populate this pattern. The function is intended to be used inside passes to
11 // simplify the matching of commutative operations by fixing the order of their
12 // operands.
14 //===----------------------------------------------------------------------===//
16 #include "mlir/Transforms/CommutativityUtils.h"
18 #include <queue>
20 using namespace mlir;
/// Classifies an ancestor, i.e. an op or a block argument that appears in the
/// backward slice of a value.
///
/// The enumerators are declared in ascending rank and the values are spelled
/// out on purpose: `AncestorKey::operator<` compares the `type` field first,
/// so this order decides which kind of ancestor sorts first.
enum AncestorType {
  /// The ancestor is a block argument (lowest rank).
  BLOCK_ARGUMENT = 0,
  /// The ancestor is an op that does not have the `ConstantLike` trait.
  NON_CONSTANT_OP = 1,
  /// The ancestor is an op that has the `ConstantLike` trait (highest rank).
  CONSTANT_OP = 2
};
35 /// Stores the "key" associated with an ancestor.
36 struct AncestorKey {
37 /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
38 /// the ancestor.
39 AncestorType type;
41 /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
42 /// `CONSTANT_OP`. Else, holds "".
43 StringRef opName;
45 /// Constructor for `AncestorKey`.
46 AncestorKey(Operation *op) {
47 if (!op) {
48 type = BLOCK_ARGUMENT;
49 } else {
50 type =
51 op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP;
52 opName = op->getName().getStringRef();
56 /// Overloaded operator `<` for `AncestorKey`.
57 ///
58 /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
59 /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
60 /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
61 /// ones are the ones with smaller op names (lexicographically).
62 ///
63 /// TODO: Include other information like attributes, value type, etc., to
64 /// enhance this comparison. For example, currently this comparison doesn't
65 /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
66 /// `addi (in i64)`. Such an enhancement should only be done if the need
67 /// arises.
68 bool operator<(const AncestorKey &key) const {
69 return std::tie(type, opName) < std::tie(key.type, key.opName);
73 /// Stores a commutative operand along with its BFS traversal information.
74 struct CommutativeOperand {
75 /// Stores the operand.
76 Value operand;
78 /// Stores the queue of ancestors of the operand's BFS traversal at a
79 /// particular point in time.
80 std::queue<Operation *> ancestorQueue;
82 /// Stores the list of ancestors that have been visited by the BFS traversal
83 /// at a particular point in time.
84 DenseSet<Operation *> visitedAncestors;
86 /// Stores the operand's "key". This "key" is defined as a list of the
87 /// "AncestorKeys" associated with the ancestors of this operand, in a
88 /// breadth-first order.
89 ///
90 /// So, if an operand, say `A`, was produced as follows:
91 ///
92 /// `<block argument>` `<block argument>`
93 /// \ /
94 /// \ /
95 /// `arith.subi` `arith.constant`
96 /// \ /
97 /// `arith.addi`
98 /// |
99 /// returns `A`
101 /// Then, the ancestors of `A`, in the breadth-first order are:
102 /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
103 /// `<block argument>`.
105 /// Thus, the "key" associated with operand `A` is:
106 /// {
107 /// {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
108 /// {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
109 /// {type: `CONSTANT_OP`, opName: "arith.constant"},
110 /// {type: `BLOCK_ARGUMENT`, opName: ""},
111 /// {type: `BLOCK_ARGUMENT`, opName: ""}
112 /// }
113 SmallVector<AncestorKey, 4> key;
115 /// Push an ancestor into the operand's BFS information structure. This
116 /// entails it being pushed into the queue (always) and inserted into the
117 /// "visited ancestors" list (iff it is an op rather than a block argument).
118 void pushAncestor(Operation *op) {
119 ancestorQueue.push(op);
120 if (op)
121 visitedAncestors.insert(op);
124 /// Refresh the key.
126 /// Refreshing a key entails making it up-to-date with the operand's BFS
127 /// traversal that has happened till that point in time, i.e, appending the
128 /// existing key with the front ancestor's "AncestorKey". Note that a key
129 /// directly reflects the BFS and thus needs to be refreshed during the
130 /// progression of the traversal.
131 void refreshKey() {
132 if (ancestorQueue.empty())
133 return;
135 Operation *frontAncestor = ancestorQueue.front();
136 AncestorKey frontAncestorKey(frontAncestor);
137 key.push_back(frontAncestorKey);
140 /// Pop the front ancestor, if any, from the queue and then push its adjacent
141 /// unvisited ancestors, if any, to the queue (this is the main body of the
142 /// BFS algorithm).
143 void popFrontAndPushAdjacentUnvisitedAncestors() {
144 if (ancestorQueue.empty())
145 return;
146 Operation *frontAncestor = ancestorQueue.front();
147 ancestorQueue.pop();
148 if (!frontAncestor)
149 return;
150 for (Value operand : frontAncestor->getOperands()) {
151 Operation *operandDefOp = operand.getDefiningOp();
152 if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
153 pushAncestor(operandDefOp);
158 /// Sorts the operands of `op` in ascending order of the "key" associated with
159 /// each operand iff `op` is commutative. This is a stable sort.
161 /// After the application of this pattern, since the commutative operands now
162 /// have a deterministic order in which they occur in an op, the matching of
163 /// large DAGs becomes much simpler, i.e., requires much less number of checks
164 /// to be written by a user in her/his pattern matching function.
166 /// Some examples of such a sorting:
168 /// Assume that the sorting is being applied to `foo.commutative`, which is a
169 /// commutative op.
171 /// Example 1:
173 /// %1 = foo.const 0
174 /// %2 = foo.mul <block argument>, <block argument>
175 /// %3 = foo.commutative %1, %2
177 /// Here,
178 /// 1. The key associated with %1 is:
179 /// `{
180 /// {CONSTANT_OP, "foo.const"}
181 /// }`
182 /// 2. The key associated with %2 is:
183 /// `{
184 /// {NON_CONSTANT_OP, "foo.mul"},
185 /// {BLOCK_ARGUMENT, ""},
186 /// {BLOCK_ARGUMENT, ""}
187 /// }`
189 /// The key of %2 < the key of %1
190 /// Thus, the sorted `foo.commutative` is:
191 /// %3 = foo.commutative %2, %1
193 /// Example 2:
195 /// %1 = foo.const 0
196 /// %2 = foo.mul <block argument>, <block argument>
197 /// %3 = foo.mul %2, %1
198 /// %4 = foo.add %2, %1
199 /// %5 = foo.commutative %1, %2, %3, %4
201 /// Here,
202 /// 1. The key associated with %1 is:
203 /// `{
204 /// {CONSTANT_OP, "foo.const"}
205 /// }`
206 /// 2. The key associated with %2 is:
207 /// `{
208 /// {NON_CONSTANT_OP, "foo.mul"},
209 /// {BLOCK_ARGUMENT, ""}
210 /// }`
211 /// 3. The key associated with %3 is:
212 /// `{
213 /// {NON_CONSTANT_OP, "foo.mul"},
214 /// {NON_CONSTANT_OP, "foo.mul"},
215 /// {CONSTANT_OP, "foo.const"},
216 /// {BLOCK_ARGUMENT, ""},
217 /// {BLOCK_ARGUMENT, ""}
218 /// }`
219 /// 4. The key associated with %4 is:
220 /// `{
221 /// {NON_CONSTANT_OP, "foo.add"},
222 /// {NON_CONSTANT_OP, "foo.mul"},
223 /// {CONSTANT_OP, "foo.const"},
224 /// {BLOCK_ARGUMENT, ""},
225 /// {BLOCK_ARGUMENT, ""}
226 /// }`
228 /// Thus, the sorted `foo.commutative` is:
229 /// %5 = foo.commutative %4, %3, %2, %1
230 class SortCommutativeOperands : public RewritePattern {
231 public:
232 SortCommutativeOperands(MLIRContext *context)
233 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}
234 LogicalResult matchAndRewrite(Operation *op,
235 PatternRewriter &rewriter) const override {
236 // Custom comparator for two commutative operands, which returns true iff
237 // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`,
238 // i.e.,
239 // 1. In the first unequal pair of corresponding AncestorKeys, the
240 // AncestorKey in `constCommOperandA` is smaller, or,
241 // 2. Both the AncestorKeys in every pair are the same and the size of
242 // `constCommOperandA`'s "key" is smaller.
243 auto commutativeOperandComparator =
244 [](const std::unique_ptr<CommutativeOperand> &constCommOperandA,
245 const std::unique_ptr<CommutativeOperand> &constCommOperandB) {
246 if (constCommOperandA->operand == constCommOperandB->operand)
247 return false;
249 auto &commOperandA =
250 const_cast<std::unique_ptr<CommutativeOperand> &>(
251 constCommOperandA);
252 auto &commOperandB =
253 const_cast<std::unique_ptr<CommutativeOperand> &>(
254 constCommOperandB);
256 // Iteratively perform the BFS's of both operands until an order among
257 // them can be determined.
258 unsigned keyIndex = 0;
259 while (true) {
260 if (commOperandA->key.size() <= keyIndex) {
261 if (commOperandA->ancestorQueue.empty())
262 return true;
263 commOperandA->popFrontAndPushAdjacentUnvisitedAncestors();
264 commOperandA->refreshKey();
266 if (commOperandB->key.size() <= keyIndex) {
267 if (commOperandB->ancestorQueue.empty())
268 return false;
269 commOperandB->popFrontAndPushAdjacentUnvisitedAncestors();
270 commOperandB->refreshKey();
272 if (commOperandA->ancestorQueue.empty() ||
273 commOperandB->ancestorQueue.empty())
274 return commOperandA->key.size() < commOperandB->key.size();
275 if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex])
276 return true;
277 if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex])
278 return false;
279 keyIndex++;
283 // If `op` is not commutative, do nothing.
284 if (!op->hasTrait<OpTrait::IsCommutative>())
285 return failure();
287 // Populate the list of commutative operands.
288 SmallVector<Value, 2> operands = op->getOperands();
289 SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands;
290 for (Value operand : operands) {
291 std::unique_ptr<CommutativeOperand> commOperand =
292 std::make_unique<CommutativeOperand>();
293 commOperand->operand = operand;
294 commOperand->pushAncestor(operand.getDefiningOp());
295 commOperand->refreshKey();
296 commOperands.push_back(std::move(commOperand));
299 // Sort the operands.
300 std::stable_sort(commOperands.begin(), commOperands.end(),
301 commutativeOperandComparator);
302 SmallVector<Value, 2> sortedOperands;
303 for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands)
304 sortedOperands.push_back(commOperand->operand);
305 if (sortedOperands == operands)
306 return failure();
307 rewriter.modifyOpInPlace(op, [&] { op->setOperands(sortedOperands); });
308 return success();
312 void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) {
313 patterns.add<SortCommutativeOperands>(patterns.getContext());