1 //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "PredicateTree.h"
10 #include "RootOrdering.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/Interfaces/InferTypeOpInterface.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "pdl-predicate-tree"
26 using namespace mlir::pdl_to_pdl_interp
;
28 //===----------------------------------------------------------------------===//
29 // Predicate List Building
30 //===----------------------------------------------------------------------===//
32 static void getTreePredicates(std::vector
<PositionalPredicate
> &predList
,
33 Value val
, PredicateBuilder
&builder
,
34 DenseMap
<Value
, Position
*> &inputs
,
37 /// Compares the depths of two positions.
38 static bool comparePosDepth(Position
*lhs
, Position
*rhs
) {
39 return lhs
->getOperationDepth() < rhs
->getOperationDepth();
42 /// Returns the number of non-range elements within `values`.
43 static unsigned getNumNonRangeValues(ValueRange values
) {
44 return llvm::count_if(values
.getTypes(),
45 [](Type type
) { return !isa
<pdl::RangeType
>(type
); });
48 static void getTreePredicates(std::vector
<PositionalPredicate
> &predList
,
49 Value val
, PredicateBuilder
&builder
,
50 DenseMap
<Value
, Position
*> &inputs
,
51 AttributePosition
*pos
) {
52 assert(isa
<pdl::AttributeType
>(val
.getType()) && "expected attribute type");
53 predList
.emplace_back(pos
, builder
.getIsNotNull());
55 if (auto attr
= dyn_cast
<pdl::AttributeOp
>(val
.getDefiningOp())) {
56 // If the attribute has a type or value, add a constraint.
57 if (Value type
= attr
.getValueType())
58 getTreePredicates(predList
, type
, builder
, inputs
, builder
.getType(pos
));
59 else if (Attribute value
= attr
.getValueAttr())
60 predList
.emplace_back(pos
, builder
.getAttributeConstraint(value
));
64 /// Collect all of the predicates for the given operand position.
65 static void getOperandTreePredicates(std::vector
<PositionalPredicate
> &predList
,
66 Value val
, PredicateBuilder
&builder
,
67 DenseMap
<Value
, Position
*> &inputs
,
69 Type valueType
= val
.getType();
70 bool isVariadic
= isa
<pdl::RangeType
>(valueType
);
72 // If this is a typed operand, add a type constraint.
73 TypeSwitch
<Operation
*>(val
.getDefiningOp())
74 .Case
<pdl::OperandOp
, pdl::OperandsOp
>([&](auto op
) {
75 // Prevent traversal into a null value if the operand has a proper
77 if (std::is_same
<pdl::OperandOp
, decltype(op
)>::value
||
78 cast
<OperandGroupPosition
>(pos
)->getOperandGroupNumber())
79 predList
.emplace_back(pos
, builder
.getIsNotNull());
81 if (Value type
= op
.getValueType())
82 getTreePredicates(predList
, type
, builder
, inputs
,
83 builder
.getType(pos
));
85 .Case
<pdl::ResultOp
, pdl::ResultsOp
>([&](auto op
) {
86 std::optional
<unsigned> index
= op
.getIndex();
88 // Prevent traversal into a null value if the result has a proper index.
90 predList
.emplace_back(pos
, builder
.getIsNotNull());
92 // Get the parent operation of this operand.
93 OperationPosition
*parentPos
= builder
.getOperandDefiningOp(pos
);
94 predList
.emplace_back(parentPos
, builder
.getIsNotNull());
96 // Ensure that the operands match the corresponding results of the
98 Position
*resultPos
= nullptr;
99 if (std::is_same
<pdl::ResultOp
, decltype(op
)>::value
)
100 resultPos
= builder
.getResult(parentPos
, *index
);
102 resultPos
= builder
.getResultGroup(parentPos
, index
, isVariadic
);
103 predList
.emplace_back(resultPos
, builder
.getEqualTo(pos
));
105 // Collect the predicates of the parent operation.
106 getTreePredicates(predList
, op
.getParent(), builder
, inputs
,
107 (Position
*)parentPos
);
112 getTreePredicates(std::vector
<PositionalPredicate
> &predList
, Value val
,
113 PredicateBuilder
&builder
,
114 DenseMap
<Value
, Position
*> &inputs
, OperationPosition
*pos
,
115 std::optional
<unsigned> ignoreOperand
= std::nullopt
) {
116 assert(isa
<pdl::OperationType
>(val
.getType()) && "expected operation");
117 pdl::OperationOp op
= cast
<pdl::OperationOp
>(val
.getDefiningOp());
118 OperationPosition
*opPos
= cast
<OperationPosition
>(pos
);
120 // Ensure getDefiningOp returns a non-null operation.
121 if (!opPos
->isRoot())
122 predList
.emplace_back(pos
, builder
.getIsNotNull());
124 // Check that this is the correct root operation.
125 if (std::optional
<StringRef
> opName
= op
.getOpName())
126 predList
.emplace_back(pos
, builder
.getOperationName(*opName
));
128 // Check that the operation has the proper number of operands. If there are
129 // any variable length operands, we check a minimum instead of an exact count.
130 OperandRange operands
= op
.getOperandValues();
131 unsigned minOperands
= getNumNonRangeValues(operands
);
132 if (minOperands
!= operands
.size()) {
134 predList
.emplace_back(pos
, builder
.getOperandCountAtLeast(minOperands
));
136 predList
.emplace_back(pos
, builder
.getOperandCount(minOperands
));
139 // Check that the operation has the proper number of results. If there are
140 // any variable length results, we check a minimum instead of an exact count.
141 OperandRange types
= op
.getTypeValues();
142 unsigned minResults
= getNumNonRangeValues(types
);
143 if (minResults
== types
.size())
144 predList
.emplace_back(pos
, builder
.getResultCount(types
.size()));
146 predList
.emplace_back(pos
, builder
.getResultCountAtLeast(minResults
));
148 // Recurse into any attributes, operands, or results.
149 for (auto [attrName
, attr
] :
150 llvm::zip(op
.getAttributeValueNames(), op
.getAttributeValues())) {
152 predList
, attr
, builder
, inputs
,
153 builder
.getAttribute(opPos
, cast
<StringAttr
>(attrName
).getValue()));
156 // Process the operands and results of the operation. For all values up to
157 // the first variable length value, we use the concrete operand/result
158 // number. After that, we use the "group" given that we can't know the
159 // concrete indices until runtime. If there is only one variadic operand
160 // group, we treat it as all of the operands/results of the operation.
162 if (operands
.size() == 1 && isa
<pdl::RangeType
>(operands
[0].getType())) {
163 // Ignore the operands if we are performing an upward traversal (in that
164 // case, they have already been visited).
165 if (opPos
->isRoot() || opPos
->isOperandDefiningOp())
166 getTreePredicates(predList
, operands
.front(), builder
, inputs
,
167 builder
.getAllOperands(opPos
));
169 bool foundVariableLength
= false;
170 for (const auto &operandIt
: llvm::enumerate(operands
)) {
171 bool isVariadic
= isa
<pdl::RangeType
>(operandIt
.value().getType());
172 foundVariableLength
|= isVariadic
;
174 // Ignore the specified operand, usually because this position was
175 // visited in an upward traversal via an iterative choice.
176 if (ignoreOperand
&& *ignoreOperand
== operandIt
.index())
181 ? builder
.getOperandGroup(opPos
, operandIt
.index(), isVariadic
)
182 : builder
.getOperand(opPos
, operandIt
.index());
183 getTreePredicates(predList
, operandIt
.value(), builder
, inputs
, pos
);
187 if (types
.size() == 1 && isa
<pdl::RangeType
>(types
[0].getType())) {
188 getTreePredicates(predList
, types
.front(), builder
, inputs
,
189 builder
.getType(builder
.getAllResults(opPos
)));
193 bool foundVariableLength
= false;
194 for (auto [idx
, typeValue
] : llvm::enumerate(types
)) {
195 bool isVariadic
= isa
<pdl::RangeType
>(typeValue
.getType());
196 foundVariableLength
|= isVariadic
;
198 auto *resultPos
= foundVariableLength
199 ? builder
.getResultGroup(pos
, idx
, isVariadic
)
200 : builder
.getResult(pos
, idx
);
201 predList
.emplace_back(resultPos
, builder
.getIsNotNull());
202 getTreePredicates(predList
, typeValue
, builder
, inputs
,
203 builder
.getType(resultPos
));
207 static void getTreePredicates(std::vector
<PositionalPredicate
> &predList
,
208 Value val
, PredicateBuilder
&builder
,
209 DenseMap
<Value
, Position
*> &inputs
,
211 // Check for a constraint on a constant type.
212 if (pdl::TypeOp typeOp
= val
.getDefiningOp
<pdl::TypeOp
>()) {
213 if (Attribute type
= typeOp
.getConstantTypeAttr())
214 predList
.emplace_back(pos
, builder
.getTypeConstraint(type
));
215 } else if (pdl::TypesOp typeOp
= val
.getDefiningOp
<pdl::TypesOp
>()) {
216 if (Attribute typeAttr
= typeOp
.getConstantTypesAttr())
217 predList
.emplace_back(pos
, builder
.getTypeConstraint(typeAttr
));
221 /// Collect the tree predicates anchored at the given value.
222 static void getTreePredicates(std::vector
<PositionalPredicate
> &predList
,
223 Value val
, PredicateBuilder
&builder
,
224 DenseMap
<Value
, Position
*> &inputs
,
226 // Make sure this input value is accessible to the rewrite.
227 auto it
= inputs
.try_emplace(val
, pos
);
229 // If this is an input value that has been visited in the tree, add a
230 // constraint to ensure that both instances refer to the same value.
231 if (isa
<pdl::AttributeOp
, pdl::OperandOp
, pdl::OperandsOp
, pdl::OperationOp
,
232 pdl::TypeOp
>(val
.getDefiningOp())) {
233 auto minMaxPositions
=
234 std::minmax(pos
, it
.first
->second
, comparePosDepth
);
235 predList
.emplace_back(minMaxPositions
.second
,
236 builder
.getEqualTo(minMaxPositions
.first
));
241 TypeSwitch
<Position
*>(pos
)
242 .Case
<AttributePosition
, OperationPosition
, TypePosition
>([&](auto *pos
) {
243 getTreePredicates(predList
, val
, builder
, inputs
, pos
);
245 .Case
<OperandPosition
, OperandGroupPosition
>([&](auto *pos
) {
246 getOperandTreePredicates(predList
, val
, builder
, inputs
, pos
);
248 .Default([](auto *) { llvm_unreachable("unexpected position kind"); });
251 static void getAttributePredicates(pdl::AttributeOp op
,
252 std::vector
<PositionalPredicate
> &predList
,
253 PredicateBuilder
&builder
,
254 DenseMap
<Value
, Position
*> &inputs
) {
255 Position
*&attrPos
= inputs
[op
];
258 Attribute value
= op
.getValueAttr();
259 assert(value
&& "expected non-tree `pdl.attribute` to contain a value");
260 attrPos
= builder
.getAttributeLiteral(value
);
263 static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op
,
264 std::vector
<PositionalPredicate
> &predList
,
265 PredicateBuilder
&builder
,
266 DenseMap
<Value
, Position
*> &inputs
) {
267 OperandRange arguments
= op
.getArgs();
269 std::vector
<Position
*> allPositions
;
270 allPositions
.reserve(arguments
.size());
271 for (Value arg
: arguments
)
272 allPositions
.push_back(inputs
.lookup(arg
));
274 // Push the constraint to the furthest position.
275 Position
*pos
= *llvm::max_element(allPositions
, comparePosDepth
);
276 ResultRange results
= op
.getResults();
277 PredicateBuilder::Predicate pred
= builder
.getConstraint(
278 op
.getName(), allPositions
, SmallVector
<Type
>(results
.getTypes()),
281 // For each result register a position so it can be used later
282 for (auto [i
, result
] : llvm::enumerate(results
)) {
283 ConstraintQuestion
*q
= cast
<ConstraintQuestion
>(pred
.first
);
284 ConstraintPosition
*pos
= builder
.getConstraintPosition(q
, i
);
285 auto [it
, inserted
] = inputs
.try_emplace(result
, pos
);
286 // If this is an input value that has been visited in the tree, add a
287 // constraint to ensure that both instances refer to the same value.
289 Position
*first
= pos
;
290 Position
*second
= it
->second
;
291 if (comparePosDepth(second
, first
))
292 std::tie(second
, first
) = std::make_pair(first
, second
);
294 predList
.emplace_back(second
, builder
.getEqualTo(first
));
297 predList
.emplace_back(pos
, pred
);
300 static void getResultPredicates(pdl::ResultOp op
,
301 std::vector
<PositionalPredicate
> &predList
,
302 PredicateBuilder
&builder
,
303 DenseMap
<Value
, Position
*> &inputs
) {
304 Position
*&resultPos
= inputs
[op
];
308 // Ensure that the result isn't null.
309 auto *parentPos
= cast
<OperationPosition
>(inputs
.lookup(op
.getParent()));
310 resultPos
= builder
.getResult(parentPos
, op
.getIndex());
311 predList
.emplace_back(resultPos
, builder
.getIsNotNull());
314 static void getResultPredicates(pdl::ResultsOp op
,
315 std::vector
<PositionalPredicate
> &predList
,
316 PredicateBuilder
&builder
,
317 DenseMap
<Value
, Position
*> &inputs
) {
318 Position
*&resultPos
= inputs
[op
];
322 // Ensure that the result isn't null if the result has an index.
323 auto *parentPos
= cast
<OperationPosition
>(inputs
.lookup(op
.getParent()));
324 bool isVariadic
= isa
<pdl::RangeType
>(op
.getType());
325 std::optional
<unsigned> index
= op
.getIndex();
326 resultPos
= builder
.getResultGroup(parentPos
, index
, isVariadic
);
328 predList
.emplace_back(resultPos
, builder
.getIsNotNull());
331 static void getTypePredicates(Value typeValue
,
332 function_ref
<Attribute()> typeAttrFn
,
333 PredicateBuilder
&builder
,
334 DenseMap
<Value
, Position
*> &inputs
) {
335 Position
*&typePos
= inputs
[typeValue
];
338 Attribute typeAttr
= typeAttrFn();
340 "expected non-tree `pdl.type`/`pdl.types` to contain a value");
341 typePos
= builder
.getTypeLiteral(typeAttr
);
344 /// Collect all of the predicates that cannot be determined via walking the
346 static void getNonTreePredicates(pdl::PatternOp pattern
,
347 std::vector
<PositionalPredicate
> &predList
,
348 PredicateBuilder
&builder
,
349 DenseMap
<Value
, Position
*> &inputs
) {
350 for (Operation
&op
: pattern
.getBodyRegion().getOps()) {
351 TypeSwitch
<Operation
*>(&op
)
352 .Case([&](pdl::AttributeOp attrOp
) {
353 getAttributePredicates(attrOp
, predList
, builder
, inputs
);
355 .Case
<pdl::ApplyNativeConstraintOp
>([&](auto constraintOp
) {
356 getConstraintPredicates(constraintOp
, predList
, builder
, inputs
);
358 .Case
<pdl::ResultOp
, pdl::ResultsOp
>([&](auto resultOp
) {
359 getResultPredicates(resultOp
, predList
, builder
, inputs
);
361 .Case([&](pdl::TypeOp typeOp
) {
363 typeOp
, [&] { return typeOp
.getConstantTypeAttr(); }, builder
,
366 .Case([&](pdl::TypesOp typeOp
) {
368 typeOp
, [&] { return typeOp
.getConstantTypesAttr(); }, builder
,
376 /// An op accepting a value at an optional index.
379 std::optional
<unsigned> index
;
382 /// The parent and operand index of each operation for each root, stored
383 /// as a nested map [root][operation].
384 using ParentMaps
= DenseMap
<Value
, DenseMap
<Value
, OpIndex
>>;
388 /// Given a pattern, determines the set of roots present in this pattern.
389 /// These are the operations whose results are not consumed by other operations.
390 static SmallVector
<Value
> detectRoots(pdl::PatternOp pattern
) {
391 // First, collect all the operations that are used as operands
392 // to other operations. These are not roots by default.
393 DenseSet
<Value
> used
;
394 for (auto operationOp
: pattern
.getBodyRegion().getOps
<pdl::OperationOp
>()) {
395 for (Value operand
: operationOp
.getOperandValues())
396 TypeSwitch
<Operation
*>(operand
.getDefiningOp())
397 .Case
<pdl::ResultOp
, pdl::ResultsOp
>(
398 [&used
](auto resultOp
) { used
.insert(resultOp
.getParent()); });
401 // Remove the specified root from the use set, so that we can
402 // always select it as a root, even if it is used by other operations.
403 if (Value root
= pattern
.getRewriter().getRoot())
406 // Finally, collect all the unused operations.
407 SmallVector
<Value
> roots
;
408 for (Value operationOp
: pattern
.getBodyRegion().getOps
<pdl::OperationOp
>())
409 if (!used
.contains(operationOp
))
410 roots
.push_back(operationOp
);
415 /// Given a list of candidate roots, builds the cost graph for connecting them.
416 /// The graph is formed by traversing the DAG of operations starting from each
417 /// root and marking the depth of each connector value (operand). Then we join
418 /// the candidate roots based on the common connector values, taking the one
419 /// with the minimum depth. Along the way, we compute, for each candidate root,
420 /// a mapping from each operation (in the DAG underneath this root) to its
421 /// parent operation and the corresponding operand index.
422 static void buildCostGraph(ArrayRef
<Value
> roots
, RootOrderingGraph
&graph
,
423 ParentMaps
&parentMaps
) {
425 // The entry of a queue. The entry consists of the following items:
426 // * the value in the DAG underneath the root;
427 // * the parent of the value;
428 // * the operand index of the value in its parent;
429 // * the depth of the visited value.
431 Entry(Value value
, Value parent
, std::optional
<unsigned> index
,
433 : value(value
), parent(parent
), index(index
), depth(depth
) {}
437 std::optional
<unsigned> index
;
441 // A root of a value and its depth (distance from root to the value).
447 // Map from candidate connector values to their roots and depths. Using a
448 // small vector with 1 entry because most values belong to a single root.
449 llvm::MapVector
<Value
, SmallVector
<RootDepth
, 1>> connectorsRootsDepths
;
451 // Perform a breadth-first traversal of the op DAG rooted at each root.
452 for (Value root
: roots
) {
453 // The queue of visited values. A value may be present multiple times in
454 // the queue, for multiple parents. We only accept the first occurrence,
455 // which is guaranteed to have the lowest depth.
456 std::queue
<Entry
> toVisit
;
457 toVisit
.emplace(root
, Value(), 0, 0);
459 // The map from value to its parent for the current root.
460 DenseMap
<Value
, OpIndex
> &parentMap
= parentMaps
[root
];
462 while (!toVisit
.empty()) {
463 Entry entry
= toVisit
.front();
465 // Skip if already visited.
466 if (!parentMap
.insert({entry
.value
, {entry
.parent
, entry
.index
}}).second
)
469 // Mark the root and depth of the value.
470 connectorsRootsDepths
[entry
.value
].push_back({root
, entry
.depth
});
472 // Traverse the operands of an operation and result ops.
473 // We intentionally do not traverse attributes and types, because those
474 // are expensive to join on.
475 TypeSwitch
<Operation
*>(entry
.value
.getDefiningOp())
476 .Case
<pdl::OperationOp
>([&](auto operationOp
) {
477 OperandRange operands
= operationOp
.getOperandValues();
478 // Special case when we pass all the operands in one range.
479 // For those, the index is empty.
480 if (operands
.size() == 1 &&
481 isa
<pdl::RangeType
>(operands
[0].getType())) {
482 toVisit
.emplace(operands
[0], entry
.value
, std::nullopt
,
487 // Default case: visit all the operands.
489 llvm::enumerate(operationOp
.getOperandValues()))
490 toVisit
.emplace(p
.value(), entry
.value
, p
.index(),
493 .Case
<pdl::ResultOp
, pdl::ResultsOp
>([&](auto resultOp
) {
494 toVisit
.emplace(resultOp
.getParent(), entry
.value
,
495 resultOp
.getIndex(), entry
.depth
);
500 // Now build the cost graph.
501 // This is simply a minimum over all depths for the target root.
503 for (const auto &connectorRootsDepths
: connectorsRootsDepths
) {
504 Value value
= connectorRootsDepths
.first
;
505 ArrayRef
<RootDepth
> rootsDepths
= connectorRootsDepths
.second
;
506 // If there is only one root for this value, this will not trigger
507 // any edges in the cost graph (a perf optimization).
508 if (rootsDepths
.size() == 1)
511 for (const RootDepth
&p
: rootsDepths
) {
512 for (const RootDepth
&q
: rootsDepths
) {
515 // Insert or retrieve the property of edge from p to q.
516 RootOrderingEntry
&entry
= graph
[q
.root
][p
.root
];
517 if (!entry
.connector
/* new edge */ || entry
.cost
.first
> q
.depth
) {
518 if (!entry
.connector
)
519 entry
.cost
.second
= nextID
++;
520 entry
.cost
.first
= q
.depth
;
521 entry
.connector
= value
;
527 assert((llvm::hasSingleElement(roots
) || graph
.size() == roots
.size()) &&
528 "the pattern contains a candidate root disconnected from the others");
531 /// Returns true if the operand at the given index needs to be queried using an
532 /// operand group, i.e., if it is variadic itself or follows a variadic operand.
533 static bool useOperandGroup(pdl::OperationOp op
, unsigned index
) {
534 OperandRange operands
= op
.getOperandValues();
535 assert(index
< operands
.size() && "operand index out of range");
536 for (unsigned i
= 0; i
<= index
; ++i
)
537 if (isa
<pdl::RangeType
>(operands
[i
].getType()))
542 /// Visit a node during upward traversal.
543 static void visitUpward(std::vector
<PositionalPredicate
> &predList
,
544 OpIndex opIndex
, PredicateBuilder
&builder
,
545 DenseMap
<Value
, Position
*> &valueToPosition
,
546 Position
*&pos
, unsigned rootID
) {
547 Value value
= opIndex
.parent
;
548 TypeSwitch
<Operation
*>(value
.getDefiningOp())
549 .Case
<pdl::OperationOp
>([&](auto operationOp
) {
550 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value
<< "\n");
552 // Get users and iterate over them.
553 Position
*usersPos
= builder
.getUsers(pos
, /*useRepresentative=*/true);
554 Position
*foreachPos
= builder
.getForEach(usersPos
, rootID
);
555 OperationPosition
*opPos
= builder
.getPassthroughOp(foreachPos
);
557 // Compare the operand(s) of the user against the input value(s).
558 Position
*operandPos
;
559 if (!opIndex
.index
) {
560 // We are querying all the operands of the operation.
561 operandPos
= builder
.getAllOperands(opPos
);
562 } else if (useOperandGroup(operationOp
, *opIndex
.index
)) {
563 // We are querying an operand group.
564 Type type
= operationOp
.getOperandValues()[*opIndex
.index
].getType();
565 bool variadic
= isa
<pdl::RangeType
>(type
);
566 operandPos
= builder
.getOperandGroup(opPos
, opIndex
.index
, variadic
);
568 // We are querying an individual operand.
569 operandPos
= builder
.getOperand(opPos
, *opIndex
.index
);
571 predList
.emplace_back(operandPos
, builder
.getEqualTo(pos
));
573 // Guard against duplicate upward visits. These are not possible,
574 // because if this value was already visited, it would have been
575 // cheaper to start the traversal at this value rather than at the
576 // `connector`, violating the optimality of our spanning tree.
577 bool inserted
= valueToPosition
.try_emplace(value
, opPos
).second
;
579 assert(inserted
&& "duplicate upward visit");
581 // Obtain the tree predicates at the current value.
582 getTreePredicates(predList
, value
, builder
, valueToPosition
, opPos
,
585 // Update the position
588 .Case
<pdl::ResultOp
>([&](auto resultOp
) {
589 // Traverse up an individual result.
590 auto *opPos
= dyn_cast
<OperationPosition
>(pos
);
591 assert(opPos
&& "operations and results must be interleaved");
592 pos
= builder
.getResult(opPos
, *opIndex
.index
);
594 // Insert the result position in case we have not visited it yet.
595 valueToPosition
.try_emplace(value
, pos
);
597 .Case
<pdl::ResultsOp
>([&](auto resultOp
) {
598 // Traverse up a group of results.
599 auto *opPos
= dyn_cast
<OperationPosition
>(pos
);
600 assert(opPos
&& "operations and results must be interleaved");
601 bool isVariadic
= isa
<pdl::RangeType
>(value
.getType());
603 pos
= builder
.getResultGroup(opPos
, opIndex
.index
, isVariadic
);
605 pos
= builder
.getAllResults(opPos
);
607 // Insert the result position in case we have not visited it yet.
608 valueToPosition
.try_emplace(value
, pos
);
612 /// Given a pattern operation, build the set of matcher predicates necessary to
613 /// match this pattern.
614 static Value
buildPredicateList(pdl::PatternOp pattern
,
615 PredicateBuilder
&builder
,
616 std::vector
<PositionalPredicate
> &predList
,
617 DenseMap
<Value
, Position
*> &valueToPosition
) {
618 SmallVector
<Value
> roots
= detectRoots(pattern
);
620 // Build the root ordering graph and compute the parent maps.
621 RootOrderingGraph graph
;
622 ParentMaps parentMaps
;
623 buildCostGraph(roots
, graph
, parentMaps
);
625 llvm::dbgs() << "Graph:\n";
626 for (auto &target
: graph
) {
627 llvm::dbgs() << " * " << target
.first
.getLoc() << " " << target
.first
629 for (auto &source
: target
.second
) {
630 RootOrderingEntry
&entry
= source
.second
;
631 llvm::dbgs() << " <- " << source
.first
<< ": " << entry
.cost
.first
632 << ":" << entry
.cost
.second
<< " via "
633 << entry
.connector
.getLoc() << "\n";
638 // Solve the optimal branching problem for each candidate root, or use the
640 Value bestRoot
= pattern
.getRewriter().getRoot();
641 OptimalBranching::EdgeList bestEdges
;
643 unsigned bestCost
= 0;
644 LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n");
645 for (Value root
: roots
) {
646 OptimalBranching
solver(graph
, root
);
647 unsigned cost
= solver
.solve();
648 LLVM_DEBUG(llvm::dbgs() << " * " << root
<< ": " << cost
<< "\n");
649 if (!bestRoot
|| bestCost
> cost
) {
652 bestEdges
= solver
.preOrderTraversal(roots
);
656 OptimalBranching
solver(graph
, bestRoot
);
658 bestEdges
= solver
.preOrderTraversal(roots
);
661 // Print the best solution.
663 llvm::dbgs() << "Best tree:\n";
664 for (const std::pair
<Value
, Value
> &edge
: bestEdges
) {
665 llvm::dbgs() << " * " << edge
.first
;
667 llvm::dbgs() << " <- " << edge
.second
;
668 llvm::dbgs() << "\n";
672 LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n");
673 LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot
<< "\n");
675 // The best root is the starting point for the traversal. Get the tree
676 // predicates for the DAG rooted at bestRoot.
677 getTreePredicates(predList
, bestRoot
, builder
, valueToPosition
,
680 // Traverse the selected optimal branching. For all edges in order, traverse
681 // up starting from the connector, until the candidate root is reached, and
682 // call getTreePredicates at every node along the way.
683 for (const auto &it
: llvm::enumerate(bestEdges
)) {
684 Value target
= it
.value().first
;
685 Value source
= it
.value().second
;
687 // Check if we already visited the target root. This happens in two cases:
688 // 1) the initial root (bestRoot);
689 // 2) a root that is dominated by (contained in the subtree rooted at) an
690 // already visited root.
691 if (valueToPosition
.count(target
))
694 // Determine the connector.
695 Value connector
= graph
[target
][source
].connector
;
696 assert(connector
&& "invalid edge");
697 LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector
.getLoc() << "\n");
698 DenseMap
<Value
, OpIndex
> parentMap
= parentMaps
.lookup(target
);
699 Position
*pos
= valueToPosition
.lookup(connector
);
700 assert(pos
&& "connector has not been traversed yet");
702 // Traverse from the connector upwards towards the target root.
703 for (Value value
= connector
; value
!= target
;) {
704 OpIndex opIndex
= parentMap
.lookup(value
);
705 assert(opIndex
.parent
&& "missing parent");
706 visitUpward(predList
, opIndex
, builder
, valueToPosition
, pos
, it
.index());
707 value
= opIndex
.parent
;
711 getNonTreePredicates(pattern
, predList
, builder
, valueToPosition
);
716 //===----------------------------------------------------------------------===//
717 // Pattern Predicate Tree Merging
718 //===----------------------------------------------------------------------===//
722 /// This class represents a specific predicate applied to a position, and
723 /// provides hashing and ordering operators. This class allows for computing a
724 /// frequence sum and ordering predicates based on a cost model.
725 struct OrderedPredicate
{
726 OrderedPredicate(const std::pair
<Position
*, Qualifier
*> &ip
)
727 : position(ip
.first
), question(ip
.second
) {}
728 OrderedPredicate(const PositionalPredicate
&ip
)
729 : position(ip
.position
), question(ip
.question
) {}
731 /// The position this predicate is applied to.
734 /// The question that is applied by this predicate onto the position.
737 /// The first and second order benefit sums.
738 /// The primary sum is the number of occurrences of this predicate among all
740 unsigned primary
= 0;
741 /// The secondary sum is a squared summation of the primary sum of all of the
742 /// predicates within each pattern that contains this predicate. This allows
743 /// for favoring predicates that are more commonly shared within a pattern, as
744 /// opposed to those shared across patterns.
745 unsigned secondary
= 0;
747 /// The tie breaking ID, used to preserve a deterministic (insertion) order
748 /// among all the predicates with the same priority, depth, and position /
749 /// predicate dependency.
752 /// A map between a pattern operation and the answer to the predicate question
753 /// within that pattern.
754 DenseMap
<Operation
*, Qualifier
*> patternToAnswer
;
756 /// Returns true if this predicate is ordered before `rhs`, based on the cost
758 bool operator<(const OrderedPredicate
&rhs
) const {
760 // * higher first and secondary order sums
762 // * lower position dependency
763 // * lower predicate dependency
764 // * lower tie breaking ID
765 auto *rhsPos
= rhs
.position
;
766 return std::make_tuple(primary
, secondary
, rhsPos
->getOperationDepth(),
767 rhsPos
->getKind(), rhs
.question
->getKind(), rhs
.id
) >
768 std::make_tuple(rhs
.primary
, rhs
.secondary
,
769 position
->getOperationDepth(), position
->getKind(),
770 question
->getKind(), id
);
774 /// A DenseMapInfo for OrderedPredicate based solely on the position and
776 struct OrderedPredicateDenseInfo
{
777 using Base
= DenseMapInfo
<std::pair
<Position
*, Qualifier
*>>;
779 static OrderedPredicate
getEmptyKey() { return Base::getEmptyKey(); }
780 static OrderedPredicate
getTombstoneKey() { return Base::getTombstoneKey(); }
781 static bool isEqual(const OrderedPredicate
&lhs
,
782 const OrderedPredicate
&rhs
) {
783 return lhs
.position
== rhs
.position
&& lhs
.question
== rhs
.question
;
785 static unsigned getHashValue(const OrderedPredicate
&p
) {
786 return llvm::hash_combine(p
.position
, p
.question
);
790 /// This class wraps a set of ordered predicates that are used within a specific
791 /// pattern operation.
792 struct OrderedPredicateList
{
793 OrderedPredicateList(pdl::PatternOp pattern
, Value root
)
794 : pattern(pattern
), root(root
) {}
796 pdl::PatternOp pattern
;
798 DenseSet
<OrderedPredicate
*> predicates
;
802 /// Returns true if the given matcher refers to the same predicate as the given
803 /// ordered predicate. This means that the position and questions of the two
805 static bool isSamePredicate(MatcherNode
*node
, OrderedPredicate
*predicate
) {
806 return node
->getPosition() == predicate
->position
&&
807 node
->getQuestion() == predicate
->question
;
810 /// Get or insert a child matcher for the given parent switch node, given a
811 /// predicate and parent pattern.
812 std::unique_ptr
<MatcherNode
> &getOrCreateChild(SwitchNode
*node
,
813 OrderedPredicate
*predicate
,
814 pdl::PatternOp pattern
) {
815 assert(isSamePredicate(node
, predicate
) &&
816 "expected matcher to equal the given predicate");
818 auto it
= predicate
->patternToAnswer
.find(pattern
);
819 assert(it
!= predicate
->patternToAnswer
.end() &&
820 "expected pattern to exist in predicate");
821 return node
->getChildren()[it
->second
];
824 /// Build the matcher CFG by "pushing" patterns through by sorted predicate
825 /// order. A pattern will traverse as far as possible using common predicates
826 /// and then either diverge from the CFG or reach the end of a branch and start
827 /// creating new nodes.
828 static void propagatePattern(std::unique_ptr
<MatcherNode
> &node
,
829 OrderedPredicateList
&list
,
830 std::vector
<OrderedPredicate
*>::iterator current
,
831 std::vector
<OrderedPredicate
*>::iterator end
) {
832 if (current
== end
) {
833 // We've hit the end of a pattern, so create a successful result node.
835 std::make_unique
<SuccessNode
>(list
.pattern
, list
.root
, std::move(node
));
837 // If the pattern doesn't contain this predicate, ignore it.
838 } else if (!list
.predicates
.contains(*current
)) {
839 propagatePattern(node
, list
, std::next(current
), end
);
841 // If the current matcher node is invalid, create a new one for this
842 // position and continue propagation.
844 // Create a new node at this position and continue
845 node
= std::make_unique
<SwitchNode
>((*current
)->position
,
846 (*current
)->question
);
848 getOrCreateChild(cast
<SwitchNode
>(&*node
), *current
, list
.pattern
),
849 list
, std::next(current
), end
);
851 // If the matcher has already been created, and it is for this predicate we
852 // continue propagation to the child.
853 } else if (isSamePredicate(node
.get(), *current
)) {
855 getOrCreateChild(cast
<SwitchNode
>(&*node
), *current
, list
.pattern
),
856 list
, std::next(current
), end
);
858 // If the matcher doesn't match the current predicate, insert a branch as
859 // the common set of matchers has diverged.
861 propagatePattern(node
->getFailureNode(), list
, current
, end
);
865 /// Fold any switch nodes nested under `node` to boolean nodes when possible.
866 /// `node` is updated in-place if it is a switch.
867 static void foldSwitchToBool(std::unique_ptr
<MatcherNode
> &node
) {
871 if (SwitchNode
*switchNode
= dyn_cast
<SwitchNode
>(&*node
)) {
872 SwitchNode::ChildMapT
&children
= switchNode
->getChildren();
873 for (auto &it
: children
)
874 foldSwitchToBool(it
.second
);
876 // If the node only contains one child, collapse it into a boolean predicate
878 if (children
.size() == 1) {
879 auto *childIt
= children
.begin();
880 node
= std::make_unique
<BoolNode
>(
881 node
->getPosition(), node
->getQuestion(), childIt
->first
,
882 std::move(childIt
->second
), std::move(node
->getFailureNode()));
884 } else if (BoolNode
*boolNode
= dyn_cast
<BoolNode
>(&*node
)) {
885 foldSwitchToBool(boolNode
->getSuccessNode());
888 foldSwitchToBool(node
->getFailureNode());
891 /// Insert an exit node at the end of the failure path of the `root`.
892 static void insertExitNode(std::unique_ptr
<MatcherNode
> *root
) {
894 root
= &(*root
)->getFailureNode();
895 *root
= std::make_unique
<ExitNode
>();
898 /// Sorts the range begin/end with the partial order given by cmp.
899 template <typename Iterator
, typename Compare
>
900 static void stableTopologicalSort(Iterator begin
, Iterator end
, Compare cmp
) {
901 while (begin
!= end
) {
902 // Cannot compute sortBeforeOthers in the predicate of stable_partition
903 // because stable_partition will not keep the [begin, end) range intact
905 llvm::SmallPtrSet
<typename
Iterator::value_type
, 16> sortBeforeOthers
;
906 for (auto i
= begin
; i
!= end
; ++i
) {
907 if (std::none_of(begin
, end
, [&](auto const &b
) { return cmp(b
, *i
); }))
908 sortBeforeOthers
.insert(*i
);
911 auto const next
= std::stable_partition(begin
, end
, [&](auto const &a
) {
912 return sortBeforeOthers
.contains(a
);
914 assert(next
!= begin
&& "not a partial ordering");
919 /// Returns true if 'b' depends on a result of 'a'.
920 static bool dependsOn(OrderedPredicate
*a
, OrderedPredicate
*b
) {
921 auto *cqa
= dyn_cast
<ConstraintQuestion
>(a
->question
);
925 auto positionDependsOnA
= [&](Position
*p
) {
926 auto *cp
= dyn_cast
<ConstraintPosition
>(p
);
927 return cp
&& cp
->getQuestion() == cqa
;
930 if (auto *cqb
= dyn_cast
<ConstraintQuestion
>(b
->question
)) {
931 // Does any argument of b use a?
932 return llvm::any_of(cqb
->getArgs(), positionDependsOnA
);
934 if (auto *equalTo
= dyn_cast
<EqualToQuestion
>(b
->question
)) {
935 return positionDependsOnA(b
->position
) ||
936 positionDependsOnA(equalTo
->getValue());
938 return positionDependsOnA(b
->position
);
941 /// Given a module containing PDL pattern operations, generate a matcher tree
942 /// using the patterns within the given module and return the root matcher node.
943 std::unique_ptr
<MatcherNode
>
944 MatcherNode::generateMatcherTree(ModuleOp module
, PredicateBuilder
&builder
,
945 DenseMap
<Value
, Position
*> &valueToPosition
) {
946 // The set of predicates contained within the pattern operations of the
948 struct PatternPredicates
{
949 PatternPredicates(pdl::PatternOp pattern
, Value root
,
950 std::vector
<PositionalPredicate
> predicates
)
951 : pattern(pattern
), root(root
), predicates(std::move(predicates
)) {}
954 pdl::PatternOp pattern
;
956 /// A root of the pattern chosen among the candidate roots in pdl.rewrite.
959 /// The extracted predicates for this pattern and root.
960 std::vector
<PositionalPredicate
> predicates
;
963 SmallVector
<PatternPredicates
, 16> patternsAndPredicates
;
964 for (pdl::PatternOp pattern
: module
.getOps
<pdl::PatternOp
>()) {
965 std::vector
<PositionalPredicate
> predicateList
;
967 buildPredicateList(pattern
, builder
, predicateList
, valueToPosition
);
968 patternsAndPredicates
.emplace_back(pattern
, root
, std::move(predicateList
));
971 // Associate a pattern result with each unique predicate.
972 DenseSet
<OrderedPredicate
, OrderedPredicateDenseInfo
> uniqued
;
973 for (auto &patternAndPredList
: patternsAndPredicates
) {
974 for (auto &predicate
: patternAndPredList
.predicates
) {
975 auto it
= uniqued
.insert(predicate
);
976 it
.first
->patternToAnswer
.try_emplace(patternAndPredList
.pattern
,
978 // Mark the insertion order (0-based indexing).
980 it
.first
->id
= uniqued
.size() - 1;
984 // Associate each pattern to a set of its ordered predicates for later lookup.
985 std::vector
<OrderedPredicateList
> lists
;
986 lists
.reserve(patternsAndPredicates
.size());
987 for (auto &patternAndPredList
: patternsAndPredicates
) {
988 OrderedPredicateList
list(patternAndPredList
.pattern
,
989 patternAndPredList
.root
);
990 for (auto &predicate
: patternAndPredList
.predicates
) {
991 OrderedPredicate
*orderedPredicate
= &*uniqued
.find(predicate
);
992 list
.predicates
.insert(orderedPredicate
);
994 // Increment the primary sum for each reference to a particular predicate.
995 ++orderedPredicate
->primary
;
997 lists
.push_back(std::move(list
));
1000 // For a particular pattern, get the total primary sum and add it to the
1001 // secondary sum of each predicate. Square the primary sums to emphasize
1002 // shared predicates within rather than across patterns.
1003 for (auto &list
: lists
) {
1005 for (auto *predicate
: list
.predicates
)
1006 total
+= predicate
->primary
* predicate
->primary
;
1007 for (auto *predicate
: list
.predicates
)
1008 predicate
->secondary
+= total
;
1011 // Sort the set of predicates now that the cost primary and secondary sums
1012 // have been computed.
1013 std::vector
<OrderedPredicate
*> ordered
;
1014 ordered
.reserve(uniqued
.size());
1015 for (auto &ip
: uniqued
)
1016 ordered
.push_back(&ip
);
1017 llvm::sort(ordered
, [](OrderedPredicate
*lhs
, OrderedPredicate
*rhs
) {
1021 // Mostly keep the now established order, but also ensure that
1022 // ConstraintQuestions come after the results they use.
1023 stableTopologicalSort(ordered
.begin(), ordered
.end(), dependsOn
);
1025 // Build the matchers for each of the pattern predicate lists.
1026 std::unique_ptr
<MatcherNode
> root
;
1027 for (OrderedPredicateList
&list
: lists
)
1028 propagatePattern(root
, list
, ordered
.begin(), ordered
.end());
1030 // Collapse the graph and insert the exit node.
1031 foldSwitchToBool(root
);
1032 insertExitNode(&root
);
1036 //===----------------------------------------------------------------------===//
1038 //===----------------------------------------------------------------------===//
1040 MatcherNode::MatcherNode(TypeID matcherTypeID
, Position
*p
, Qualifier
*q
,
1041 std::unique_ptr
<MatcherNode
> failureNode
)
1042 : position(p
), question(q
), failureNode(std::move(failureNode
)),
1043 matcherTypeID(matcherTypeID
) {}
1045 //===----------------------------------------------------------------------===//
1047 //===----------------------------------------------------------------------===//
1049 BoolNode::BoolNode(Position
*position
, Qualifier
*question
, Qualifier
*answer
,
1050 std::unique_ptr
<MatcherNode
> successNode
,
1051 std::unique_ptr
<MatcherNode
> failureNode
)
1052 : MatcherNode(TypeID::get
<BoolNode
>(), position
, question
,
1053 std::move(failureNode
)),
1054 answer(answer
), successNode(std::move(successNode
)) {}
1056 //===----------------------------------------------------------------------===//
1058 //===----------------------------------------------------------------------===//
1060 SuccessNode::SuccessNode(pdl::PatternOp pattern
, Value root
,
1061 std::unique_ptr
<MatcherNode
> failureNode
)
1062 : MatcherNode(TypeID::get
<SuccessNode
>(), /*position=*/nullptr,
1063 /*question=*/nullptr, std::move(failureNode
)),
1064 pattern(pattern
), root(root
) {}
1066 //===----------------------------------------------------------------------===//
1068 //===----------------------------------------------------------------------===//
1070 SwitchNode::SwitchNode(Position
*position
, Qualifier
*question
)
1071 : MatcherNode(TypeID::get
<SwitchNode
>(), position
, question
) {}