1 //===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file implements a commutativity utility pattern and a function to
// populate this pattern. The function is intended to be used inside passes to
// simplify the matching of commutative operations by fixing the order of their
// operands.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/CommutativityUtils.h"

#include <algorithm>
#include <memory>
#include <queue>
#include <tuple>

using namespace mlir;
22 /// The possible "types" of ancestors. Here, an ancestor is an op or a block
23 /// argument present in the backward slice of a value.
25 /// Pertains to a block argument.
28 /// Pertains to a non-constant-like op.
31 /// Pertains to a constant-like op.
35 /// Stores the "key" associated with an ancestor.
37 /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
41 /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
42 /// `CONSTANT_OP`. Else, holds "".
45 /// Constructor for `AncestorKey`.
46 AncestorKey(Operation
*op
) {
48 type
= BLOCK_ARGUMENT
;
51 op
->hasTrait
<OpTrait::ConstantLike
>() ? CONSTANT_OP
: NON_CONSTANT_OP
;
52 opName
= op
->getName().getStringRef();
56 /// Overloaded operator `<` for `AncestorKey`.
58 /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
59 /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
60 /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
61 /// ones are the ones with smaller op names (lexicographically).
63 /// TODO: Include other information like attributes, value type, etc., to
64 /// enhance this comparison. For example, currently this comparison doesn't
65 /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
66 /// `addi (in i64)`. Such an enhancement should only be done if the need
68 bool operator<(const AncestorKey
&key
) const {
69 return std::tie(type
, opName
) < std::tie(key
.type
, key
.opName
);
73 /// Stores a commutative operand along with its BFS traversal information.
74 struct CommutativeOperand
{
75 /// Stores the operand.
78 /// Stores the queue of ancestors of the operand's BFS traversal at a
79 /// particular point in time.
80 std::queue
<Operation
*> ancestorQueue
;
82 /// Stores the list of ancestors that have been visited by the BFS traversal
83 /// at a particular point in time.
84 DenseSet
<Operation
*> visitedAncestors
;
86 /// Stores the operand's "key". This "key" is defined as a list of the
87 /// "AncestorKeys" associated with the ancestors of this operand, in a
88 /// breadth-first order.
90 /// So, if an operand, say `A`, was produced as follows:
92 /// `<block argument>` `<block argument>`
95 /// `arith.subi` `arith.constant`
101 /// Then, the ancestors of `A`, in the breadth-first order are:
102 /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
103 /// `<block argument>`.
105 /// Thus, the "key" associated with operand `A` is:
107 /// {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
108 /// {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
109 /// {type: `CONSTANT_OP`, opName: "arith.constant"},
110 /// {type: `BLOCK_ARGUMENT`, opName: ""},
111 /// {type: `BLOCK_ARGUMENT`, opName: ""}
113 SmallVector
<AncestorKey
, 4> key
;
115 /// Push an ancestor into the operand's BFS information structure. This
116 /// entails it being pushed into the queue (always) and inserted into the
117 /// "visited ancestors" list (iff it is an op rather than a block argument).
118 void pushAncestor(Operation
*op
) {
119 ancestorQueue
.push(op
);
121 visitedAncestors
.insert(op
);
126 /// Refreshing a key entails making it up-to-date with the operand's BFS
127 /// traversal that has happened till that point in time, i.e, appending the
128 /// existing key with the front ancestor's "AncestorKey". Note that a key
129 /// directly reflects the BFS and thus needs to be refreshed during the
130 /// progression of the traversal.
132 if (ancestorQueue
.empty())
135 Operation
*frontAncestor
= ancestorQueue
.front();
136 AncestorKey
frontAncestorKey(frontAncestor
);
137 key
.push_back(frontAncestorKey
);
140 /// Pop the front ancestor, if any, from the queue and then push its adjacent
141 /// unvisited ancestors, if any, to the queue (this is the main body of the
143 void popFrontAndPushAdjacentUnvisitedAncestors() {
144 if (ancestorQueue
.empty())
146 Operation
*frontAncestor
= ancestorQueue
.front();
150 for (Value operand
: frontAncestor
->getOperands()) {
151 Operation
*operandDefOp
= operand
.getDefiningOp();
152 if (!operandDefOp
|| !visitedAncestors
.contains(operandDefOp
))
153 pushAncestor(operandDefOp
);
158 /// Sorts the operands of `op` in ascending order of the "key" associated with
159 /// each operand iff `op` is commutative. This is a stable sort.
161 /// After the application of this pattern, since the commutative operands now
162 /// have a deterministic order in which they occur in an op, the matching of
163 /// large DAGs becomes much simpler, i.e., requires much less number of checks
164 /// to be written by a user in her/his pattern matching function.
166 /// Some examples of such a sorting:
168 /// Assume that the sorting is being applied to `foo.commutative`, which is a
174 /// %2 = foo.mul <block argument>, <block argument>
175 /// %3 = foo.commutative %1, %2
178 /// 1. The key associated with %1 is:
180 /// {CONSTANT_OP, "foo.const"}
182 /// 2. The key associated with %2 is:
184 /// {NON_CONSTANT_OP, "foo.mul"},
185 /// {BLOCK_ARGUMENT, ""},
186 /// {BLOCK_ARGUMENT, ""}
189 /// The key of %2 < the key of %1
190 /// Thus, the sorted `foo.commutative` is:
191 /// %3 = foo.commutative %2, %1
196 /// %2 = foo.mul <block argument>, <block argument>
197 /// %3 = foo.mul %2, %1
198 /// %4 = foo.add %2, %1
199 /// %5 = foo.commutative %1, %2, %3, %4
202 /// 1. The key associated with %1 is:
204 /// {CONSTANT_OP, "foo.const"}
206 /// 2. The key associated with %2 is:
208 /// {NON_CONSTANT_OP, "foo.mul"},
209 /// {BLOCK_ARGUMENT, ""}
211 /// 3. The key associated with %3 is:
213 /// {NON_CONSTANT_OP, "foo.mul"},
214 /// {NON_CONSTANT_OP, "foo.mul"},
215 /// {CONSTANT_OP, "foo.const"},
216 /// {BLOCK_ARGUMENT, ""},
217 /// {BLOCK_ARGUMENT, ""}
219 /// 4. The key associated with %4 is:
221 /// {NON_CONSTANT_OP, "foo.add"},
222 /// {NON_CONSTANT_OP, "foo.mul"},
223 /// {CONSTANT_OP, "foo.const"},
224 /// {BLOCK_ARGUMENT, ""},
225 /// {BLOCK_ARGUMENT, ""}
228 /// Thus, the sorted `foo.commutative` is:
229 /// %5 = foo.commutative %4, %3, %2, %1
230 class SortCommutativeOperands
: public RewritePattern
{
232 SortCommutativeOperands(MLIRContext
*context
)
233 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context
) {}
234 LogicalResult
matchAndRewrite(Operation
*op
,
235 PatternRewriter
&rewriter
) const override
{
236 // Custom comparator for two commutative operands, which returns true iff
237 // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`,
239 // 1. In the first unequal pair of corresponding AncestorKeys, the
240 // AncestorKey in `constCommOperandA` is smaller, or,
241 // 2. Both the AncestorKeys in every pair are the same and the size of
242 // `constCommOperandA`'s "key" is smaller.
243 auto commutativeOperandComparator
=
244 [](const std::unique_ptr
<CommutativeOperand
> &constCommOperandA
,
245 const std::unique_ptr
<CommutativeOperand
> &constCommOperandB
) {
246 if (constCommOperandA
->operand
== constCommOperandB
->operand
)
250 const_cast<std::unique_ptr
<CommutativeOperand
> &>(
253 const_cast<std::unique_ptr
<CommutativeOperand
> &>(
256 // Iteratively perform the BFS's of both operands until an order among
257 // them can be determined.
258 unsigned keyIndex
= 0;
260 if (commOperandA
->key
.size() <= keyIndex
) {
261 if (commOperandA
->ancestorQueue
.empty())
263 commOperandA
->popFrontAndPushAdjacentUnvisitedAncestors();
264 commOperandA
->refreshKey();
266 if (commOperandB
->key
.size() <= keyIndex
) {
267 if (commOperandB
->ancestorQueue
.empty())
269 commOperandB
->popFrontAndPushAdjacentUnvisitedAncestors();
270 commOperandB
->refreshKey();
272 if (commOperandA
->ancestorQueue
.empty() ||
273 commOperandB
->ancestorQueue
.empty())
274 return commOperandA
->key
.size() < commOperandB
->key
.size();
275 if (commOperandA
->key
[keyIndex
] < commOperandB
->key
[keyIndex
])
277 if (commOperandB
->key
[keyIndex
] < commOperandA
->key
[keyIndex
])
283 // If `op` is not commutative, do nothing.
284 if (!op
->hasTrait
<OpTrait::IsCommutative
>())
287 // Populate the list of commutative operands.
288 SmallVector
<Value
, 2> operands
= op
->getOperands();
289 SmallVector
<std::unique_ptr
<CommutativeOperand
>, 2> commOperands
;
290 for (Value operand
: operands
) {
291 std::unique_ptr
<CommutativeOperand
> commOperand
=
292 std::make_unique
<CommutativeOperand
>();
293 commOperand
->operand
= operand
;
294 commOperand
->pushAncestor(operand
.getDefiningOp());
295 commOperand
->refreshKey();
296 commOperands
.push_back(std::move(commOperand
));
299 // Sort the operands.
300 std::stable_sort(commOperands
.begin(), commOperands
.end(),
301 commutativeOperandComparator
);
302 SmallVector
<Value
, 2> sortedOperands
;
303 for (const std::unique_ptr
<CommutativeOperand
> &commOperand
: commOperands
)
304 sortedOperands
.push_back(commOperand
->operand
);
305 if (sortedOperands
== operands
)
307 rewriter
.modifyOpInPlace(op
, [&] { op
->setOperands(sortedOperands
); });
312 void mlir::populateCommutativityUtilsPatterns(RewritePatternSet
&patterns
) {
313 patterns
.add
<SortCommutativeOperands
>(patterns
.getContext());