//===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"

#include "PredicateTree.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::pdl_to_pdl_interp;
31 //===----------------------------------------------------------------------===//
33 //===----------------------------------------------------------------------===//
36 /// This class generators operations within the PDL Interpreter dialect from a
37 /// given module containing PDL pattern operations.
38 struct PatternLowering
{
40 PatternLowering(pdl_interp::FuncOp matcherFunc
, ModuleOp rewriterModule
,
41 DenseMap
<Operation
*, PDLPatternConfigSet
*> *configMap
);
43 /// Generate code for matching and rewriting based on the pattern operations
44 /// within the module.
45 void lower(ModuleOp module
);
48 using ValueMap
= llvm::ScopedHashTable
<Position
*, Value
>;
49 using ValueMapScope
= llvm::ScopedHashTableScope
<Position
*, Value
>;
51 /// Generate interpreter operations for the tree rooted at the given matcher
52 /// node, in the specified region.
53 Block
*generateMatcher(MatcherNode
&node
, Region
®ion
,
54 Block
*block
= nullptr);
56 /// Get or create an access to the provided positional value in the current
57 /// block. This operation may mutate the provided block pointer if nested
58 /// regions (i.e., pdl_interp.iterate) are required.
59 Value
getValueAt(Block
*¤tBlock
, Position
*pos
);
61 /// Create the interpreter predicate operations. This operation may mutate the
62 /// provided current block pointer if nested regions (iterates) are required.
63 void generate(BoolNode
*boolNode
, Block
*¤tBlock
, Value val
);
65 /// Create the interpreter switch / predicate operations, with several case
66 /// destinations. This operation never mutates the provided current block
67 /// pointer, because the switch operation does not need Values beyond `val`.
68 void generate(SwitchNode
*switchNode
, Block
*currentBlock
, Value val
);
70 /// Create the interpreter operations to record a successful pattern match
71 /// using the contained root operation. This operation may mutate the current
72 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
73 void generate(SuccessNode
*successNode
, Block
*¤tBlock
);
75 /// Generate a rewriter function for the given pattern operation, and returns
76 /// a reference to that function.
77 SymbolRefAttr
generateRewriter(pdl::PatternOp pattern
,
78 SmallVectorImpl
<Position
*> &usedMatchValues
);
80 /// Generate the rewriter code for the given operation.
81 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp
,
82 DenseMap
<Value
, Value
> &rewriteValues
,
83 function_ref
<Value(Value
)> mapRewriteValue
);
84 void generateRewriter(pdl::AttributeOp attrOp
,
85 DenseMap
<Value
, Value
> &rewriteValues
,
86 function_ref
<Value(Value
)> mapRewriteValue
);
87 void generateRewriter(pdl::EraseOp eraseOp
,
88 DenseMap
<Value
, Value
> &rewriteValues
,
89 function_ref
<Value(Value
)> mapRewriteValue
);
90 void generateRewriter(pdl::OperationOp operationOp
,
91 DenseMap
<Value
, Value
> &rewriteValues
,
92 function_ref
<Value(Value
)> mapRewriteValue
);
93 void generateRewriter(pdl::RangeOp rangeOp
,
94 DenseMap
<Value
, Value
> &rewriteValues
,
95 function_ref
<Value(Value
)> mapRewriteValue
);
96 void generateRewriter(pdl::ReplaceOp replaceOp
,
97 DenseMap
<Value
, Value
> &rewriteValues
,
98 function_ref
<Value(Value
)> mapRewriteValue
);
99 void generateRewriter(pdl::ResultOp resultOp
,
100 DenseMap
<Value
, Value
> &rewriteValues
,
101 function_ref
<Value(Value
)> mapRewriteValue
);
102 void generateRewriter(pdl::ResultsOp resultOp
,
103 DenseMap
<Value
, Value
> &rewriteValues
,
104 function_ref
<Value(Value
)> mapRewriteValue
);
105 void generateRewriter(pdl::TypeOp typeOp
,
106 DenseMap
<Value
, Value
> &rewriteValues
,
107 function_ref
<Value(Value
)> mapRewriteValue
);
108 void generateRewriter(pdl::TypesOp typeOp
,
109 DenseMap
<Value
, Value
> &rewriteValues
,
110 function_ref
<Value(Value
)> mapRewriteValue
);
112 /// Generate the values used for resolving the result types of an operation
113 /// created within a dag rewriter region. If the result types of the operation
114 /// should be inferred, `hasInferredResultTypes` is set to true.
115 void generateOperationResultTypeRewriter(
116 pdl::OperationOp op
, function_ref
<Value(Value
)> mapRewriteValue
,
117 SmallVectorImpl
<Value
> &types
, DenseMap
<Value
, Value
> &rewriteValues
,
118 bool &hasInferredResultTypes
);
120 /// A builder to use when generating interpreter operations.
123 /// The matcher function used for all match related logic within PDL patterns.
124 pdl_interp::FuncOp matcherFunc
;
126 /// The rewriter module containing the all rewrite related logic within PDL
128 ModuleOp rewriterModule
;
130 /// The symbol table of the rewriter module used for insertion.
131 SymbolTable rewriterSymbolTable
;
133 /// A scoped map connecting a position with the corresponding interpreter
137 /// A stack of blocks used as the failure destination for matcher nodes that
138 /// don't have an explicit failure path.
139 SmallVector
<Block
*, 8> failureBlockStack
;
141 /// A mapping between values defined in a pattern match, and the corresponding
142 /// positional value.
143 DenseMap
<Value
, Position
*> valueToPosition
;
145 /// The set of operation values whose location will be used for newly
146 /// generated operations.
147 SetVector
<Value
> locOps
;
149 /// A mapping between pattern operations and the corresponding configuration
151 DenseMap
<Operation
*, PDLPatternConfigSet
*> *configMap
;
153 /// A mapping from a constraint question to the ApplyConstraintOp
154 /// that implements it.
155 DenseMap
<ConstraintQuestion
*, pdl_interp::ApplyConstraintOp
> constraintOpMap
;
159 PatternLowering::PatternLowering(
160 pdl_interp::FuncOp matcherFunc
, ModuleOp rewriterModule
,
161 DenseMap
<Operation
*, PDLPatternConfigSet
*> *configMap
)
162 : builder(matcherFunc
.getContext()), matcherFunc(matcherFunc
),
163 rewriterModule(rewriterModule
), rewriterSymbolTable(rewriterModule
),
164 configMap(configMap
) {}
166 void PatternLowering::lower(ModuleOp module
) {
167 PredicateUniquer predicateUniquer
;
168 PredicateBuilder
predicateBuilder(predicateUniquer
, module
.getContext());
170 // Define top-level scope for the arguments to the matcher function.
171 ValueMapScope
topLevelValueScope(values
);
173 // Insert the root operation, i.e. argument to the matcher, at the root
175 Block
*matcherEntryBlock
= &matcherFunc
.front();
176 values
.insert(predicateBuilder
.getRoot(), matcherEntryBlock
->getArgument(0));
178 // Generate a root matcher node from the provided PDL module.
179 std::unique_ptr
<MatcherNode
> root
= MatcherNode::generateMatcherTree(
180 module
, predicateBuilder
, valueToPosition
);
181 Block
*firstMatcherBlock
= generateMatcher(*root
, matcherFunc
.getBody());
182 assert(failureBlockStack
.empty() && "failed to empty the stack");
184 // After generation, merged the first matched block into the entry.
185 matcherEntryBlock
->getOperations().splice(matcherEntryBlock
->end(),
186 firstMatcherBlock
->getOperations());
187 firstMatcherBlock
->erase();
190 Block
*PatternLowering::generateMatcher(MatcherNode
&node
, Region
®ion
,
192 // Push a new scope for the values used by this matcher.
194 block
= ®ion
.emplaceBlock();
195 ValueMapScope
scope(values
);
197 // If this is the return node, simply insert the corresponding interpreter
199 if (isa
<ExitNode
>(node
)) {
200 builder
.setInsertionPointToEnd(block
);
201 builder
.create
<pdl_interp::FinalizeOp
>(matcherFunc
.getLoc());
205 // Get the next block in the match sequence.
206 // This is intentionally executed first, before we get the value for the
207 // position associated with the node, so that we preserve an "there exist"
208 // semantics: if getting a value requires an upward traversal (going from a
209 // value to its consumers), we want to perform the check on all the consumers
210 // before we pass control to the failure node.
211 std::unique_ptr
<MatcherNode
> &failureNode
= node
.getFailureNode();
214 failureBlock
= generateMatcher(*failureNode
, region
);
215 failureBlockStack
.push_back(failureBlock
);
217 assert(!failureBlockStack
.empty() && "expected valid failure block");
218 failureBlock
= failureBlockStack
.back();
221 // If this node contains a position, get the corresponding value for this
223 Block
*currentBlock
= block
;
224 Position
*position
= node
.getPosition();
225 Value val
= position
? getValueAt(currentBlock
, position
) : Value();
227 // If this value corresponds to an operation, record that we are going to use
228 // its location as part of a fused location.
229 bool isOperationValue
= val
&& isa
<pdl::OperationType
>(val
.getType());
230 if (isOperationValue
)
233 // Dispatch to the correct method based on derived node type.
234 TypeSwitch
<MatcherNode
*>(&node
)
235 .Case
<BoolNode
, SwitchNode
>([&](auto *derivedNode
) {
236 this->generate(derivedNode
, currentBlock
, val
);
238 .Case([&](SuccessNode
*successNode
) {
239 generate(successNode
, currentBlock
);
242 // Pop all the failure blocks that were inserted due to nesting of
243 // pdl_interp.iterate.
244 while (failureBlockStack
.back() != failureBlock
) {
245 failureBlockStack
.pop_back();
246 assert(!failureBlockStack
.empty() && "unable to locate failure block");
249 // Pop the new failure block.
251 failureBlockStack
.pop_back();
253 if (isOperationValue
)
259 Value
PatternLowering::getValueAt(Block
*¤tBlock
, Position
*pos
) {
260 if (Value val
= values
.lookup(pos
))
263 // Get the value for the parent position.
265 if (Position
*parent
= pos
->getParent())
266 parentVal
= getValueAt(currentBlock
, parent
);
268 // TODO: Use a location from the position.
269 Location loc
= parentVal
? parentVal
.getLoc() : builder
.getUnknownLoc();
270 builder
.setInsertionPointToEnd(currentBlock
);
272 switch (pos
->getKind()) {
273 case Predicates::OperationPos
: {
274 auto *operationPos
= cast
<OperationPosition
>(pos
);
275 if (operationPos
->isOperandDefiningOp())
276 // Standard (downward) traversal which directly follows the defining op.
277 value
= builder
.create
<pdl_interp::GetDefiningOpOp
>(
278 loc
, builder
.getType
<pdl::OperationType
>(), parentVal
);
280 // A passthrough operation position.
284 case Predicates::UsersPos
: {
285 auto *usersPos
= cast
<UsersPosition
>(pos
);
287 // The first operation retrieves the representative value of a range.
288 // This applies only when the parent is a range of values and we were
289 // requested to use a representative value (e.g., upward traversal).
290 if (isa
<pdl::RangeType
>(parentVal
.getType()) &&
291 usersPos
->useRepresentative())
292 value
= builder
.create
<pdl_interp::ExtractOp
>(loc
, parentVal
, 0);
296 // The second operation retrieves the users.
297 value
= builder
.create
<pdl_interp::GetUsersOp
>(loc
, value
);
300 case Predicates::ForEachPos
: {
301 assert(!failureBlockStack
.empty() && "expected valid failure block");
302 auto foreach
= builder
.create
<pdl_interp::ForEachOp
>(
303 loc
, parentVal
, failureBlockStack
.back(), /*initLoop=*/true);
304 value
= foreach
.getLoopVariable();
306 // Create the continuation block.
307 Block
*continueBlock
= builder
.createBlock(&foreach
.getRegion());
308 builder
.create
<pdl_interp::ContinueOp
>(loc
);
309 failureBlockStack
.push_back(continueBlock
);
311 currentBlock
= &foreach
.getRegion().front();
314 case Predicates::OperandPos
: {
315 auto *operandPos
= cast
<OperandPosition
>(pos
);
316 value
= builder
.create
<pdl_interp::GetOperandOp
>(
317 loc
, builder
.getType
<pdl::ValueType
>(), parentVal
,
318 operandPos
->getOperandNumber());
321 case Predicates::OperandGroupPos
: {
322 auto *operandPos
= cast
<OperandGroupPosition
>(pos
);
323 Type valueTy
= builder
.getType
<pdl::ValueType
>();
324 value
= builder
.create
<pdl_interp::GetOperandsOp
>(
325 loc
, operandPos
->isVariadic() ? pdl::RangeType::get(valueTy
) : valueTy
,
326 parentVal
, operandPos
->getOperandGroupNumber());
329 case Predicates::AttributePos
: {
330 auto *attrPos
= cast
<AttributePosition
>(pos
);
331 value
= builder
.create
<pdl_interp::GetAttributeOp
>(
332 loc
, builder
.getType
<pdl::AttributeType
>(), parentVal
,
333 attrPos
->getName().strref());
336 case Predicates::TypePos
: {
337 if (isa
<pdl::AttributeType
>(parentVal
.getType()))
338 value
= builder
.create
<pdl_interp::GetAttributeTypeOp
>(loc
, parentVal
);
340 value
= builder
.create
<pdl_interp::GetValueTypeOp
>(loc
, parentVal
);
343 case Predicates::ResultPos
: {
344 auto *resPos
= cast
<ResultPosition
>(pos
);
345 value
= builder
.create
<pdl_interp::GetResultOp
>(
346 loc
, builder
.getType
<pdl::ValueType
>(), parentVal
,
347 resPos
->getResultNumber());
350 case Predicates::ResultGroupPos
: {
351 auto *resPos
= cast
<ResultGroupPosition
>(pos
);
352 Type valueTy
= builder
.getType
<pdl::ValueType
>();
353 value
= builder
.create
<pdl_interp::GetResultsOp
>(
354 loc
, resPos
->isVariadic() ? pdl::RangeType::get(valueTy
) : valueTy
,
355 parentVal
, resPos
->getResultGroupNumber());
358 case Predicates::AttributeLiteralPos
: {
359 auto *attrPos
= cast
<AttributeLiteralPosition
>(pos
);
361 builder
.create
<pdl_interp::CreateAttributeOp
>(loc
, attrPos
->getValue());
364 case Predicates::TypeLiteralPos
: {
365 auto *typePos
= cast
<TypeLiteralPosition
>(pos
);
366 Attribute rawTypeAttr
= typePos
->getValue();
367 if (TypeAttr typeAttr
= dyn_cast
<TypeAttr
>(rawTypeAttr
))
368 value
= builder
.create
<pdl_interp::CreateTypeOp
>(loc
, typeAttr
);
370 value
= builder
.create
<pdl_interp::CreateTypesOp
>(
371 loc
, cast
<ArrayAttr
>(rawTypeAttr
));
374 case Predicates::ConstraintResultPos
: {
375 // Due to the order of traversal, the ApplyConstraintOp has already been
376 // created and we can find it in constraintOpMap.
377 auto *constrResPos
= cast
<ConstraintPosition
>(pos
);
378 auto i
= constraintOpMap
.find(constrResPos
->getQuestion());
379 assert(i
!= constraintOpMap
.end());
380 value
= i
->second
->getResult(constrResPos
->getIndex());
384 llvm_unreachable("Generating unknown Position getter");
388 values
.insert(pos
, value
);
392 void PatternLowering::generate(BoolNode
*boolNode
, Block
*¤tBlock
,
394 Location loc
= val
.getLoc();
395 Qualifier
*question
= boolNode
->getQuestion();
396 Qualifier
*answer
= boolNode
->getAnswer();
397 Region
*region
= currentBlock
->getParent();
399 // Execute the getValue queries first, so that we create success
400 // matcher in the correct (possibly nested) region.
401 SmallVector
<Value
> args
;
402 if (auto *equalToQuestion
= dyn_cast
<EqualToQuestion
>(question
)) {
403 args
= {getValueAt(currentBlock
, equalToQuestion
->getValue())};
404 } else if (auto *cstQuestion
= dyn_cast
<ConstraintQuestion
>(question
)) {
405 for (Position
*position
: cstQuestion
->getArgs())
406 args
.push_back(getValueAt(currentBlock
, position
));
409 // Generate a new block as success successor and get the failure successor.
410 Block
*success
= ®ion
->emplaceBlock();
411 Block
*failure
= failureBlockStack
.back();
413 // Create the predicate.
414 builder
.setInsertionPointToEnd(currentBlock
);
415 Predicates::Kind kind
= question
->getKind();
417 case Predicates::IsNotNullQuestion
:
418 builder
.create
<pdl_interp::IsNotNullOp
>(loc
, val
, success
, failure
);
420 case Predicates::OperationNameQuestion
: {
421 auto *opNameAnswer
= cast
<OperationNameAnswer
>(answer
);
422 builder
.create
<pdl_interp::CheckOperationNameOp
>(
423 loc
, val
, opNameAnswer
->getValue().getStringRef(), success
, failure
);
426 case Predicates::TypeQuestion
: {
427 auto *ans
= cast
<TypeAnswer
>(answer
);
428 if (isa
<pdl::RangeType
>(val
.getType()))
429 builder
.create
<pdl_interp::CheckTypesOp
>(
430 loc
, val
, llvm::cast
<ArrayAttr
>(ans
->getValue()), success
, failure
);
432 builder
.create
<pdl_interp::CheckTypeOp
>(
433 loc
, val
, llvm::cast
<TypeAttr
>(ans
->getValue()), success
, failure
);
436 case Predicates::AttributeQuestion
: {
437 auto *ans
= cast
<AttributeAnswer
>(answer
);
438 builder
.create
<pdl_interp::CheckAttributeOp
>(loc
, val
, ans
->getValue(),
442 case Predicates::OperandCountAtLeastQuestion
:
443 case Predicates::OperandCountQuestion
:
444 builder
.create
<pdl_interp::CheckOperandCountOp
>(
445 loc
, val
, cast
<UnsignedAnswer
>(answer
)->getValue(),
446 /*compareAtLeast=*/kind
== Predicates::OperandCountAtLeastQuestion
,
449 case Predicates::ResultCountAtLeastQuestion
:
450 case Predicates::ResultCountQuestion
:
451 builder
.create
<pdl_interp::CheckResultCountOp
>(
452 loc
, val
, cast
<UnsignedAnswer
>(answer
)->getValue(),
453 /*compareAtLeast=*/kind
== Predicates::ResultCountAtLeastQuestion
,
456 case Predicates::EqualToQuestion
: {
457 bool trueAnswer
= isa
<TrueAnswer
>(answer
);
458 builder
.create
<pdl_interp::AreEqualOp
>(loc
, val
, args
.front(),
459 trueAnswer
? success
: failure
,
460 trueAnswer
? failure
: success
);
463 case Predicates::ConstraintQuestion
: {
464 auto *cstQuestion
= cast
<ConstraintQuestion
>(question
);
465 auto applyConstraintOp
= builder
.create
<pdl_interp::ApplyConstraintOp
>(
466 loc
, cstQuestion
->getResultTypes(), cstQuestion
->getName(), args
,
467 cstQuestion
->getIsNegated(), success
, failure
);
469 constraintOpMap
.insert({cstQuestion
, applyConstraintOp
});
473 llvm_unreachable("Generating unknown Predicate operation");
476 // Generate the matcher in the current (potentially nested) region.
477 // This might use the results of the current predicate.
478 generateMatcher(*boolNode
->getSuccessNode(), *region
, success
);
481 template <typename OpT
, typename PredT
, typename ValT
= typename
PredT::KeyTy
>
482 static void createSwitchOp(Value val
, Block
*defaultDest
, OpBuilder
&builder
,
483 llvm::MapVector
<Qualifier
*, Block
*> &dests
) {
484 std::vector
<ValT
> values
;
485 std::vector
<Block
*> blocks
;
486 values
.reserve(dests
.size());
487 blocks
.reserve(dests
.size());
488 for (const auto &it
: dests
) {
489 blocks
.push_back(it
.second
);
490 values
.push_back(cast
<PredT
>(it
.first
)->getValue());
492 builder
.create
<OpT
>(val
.getLoc(), val
, values
, defaultDest
, blocks
);
495 void PatternLowering::generate(SwitchNode
*switchNode
, Block
*currentBlock
,
497 Qualifier
*question
= switchNode
->getQuestion();
498 Region
*region
= currentBlock
->getParent();
499 Block
*defaultDest
= failureBlockStack
.back();
501 // If the switch question is not an exact answer, i.e. for the `at_least`
502 // cases, we generate a special block sequence.
503 Predicates::Kind kind
= question
->getKind();
504 if (kind
== Predicates::OperandCountAtLeastQuestion
||
505 kind
== Predicates::ResultCountAtLeastQuestion
) {
506 // Order the children such that the cases are in reverse numerical order.
507 SmallVector
<unsigned> sortedChildren
= llvm::to_vector
<16>(
508 llvm::seq
<unsigned>(0, switchNode
->getChildren().size()));
509 llvm::sort(sortedChildren
, [&](unsigned lhs
, unsigned rhs
) {
510 return cast
<UnsignedAnswer
>(switchNode
->getChild(lhs
).first
)->getValue() >
511 cast
<UnsignedAnswer
>(switchNode
->getChild(rhs
).first
)->getValue();
514 // Build the destination for each child using the next highest child as a
515 // a failure destination. This essentially creates the following control
518 // if (operand_count < 1)
520 // if (child1.match())
523 // if (operand_count < 2)
525 // if (child2.match())
531 failureBlockStack
.push_back(defaultDest
);
532 Location loc
= val
.getLoc();
533 for (unsigned idx
: sortedChildren
) {
534 auto &child
= switchNode
->getChild(idx
);
535 Block
*childBlock
= generateMatcher(*child
.second
, *region
);
536 Block
*predicateBlock
= builder
.createBlock(childBlock
);
537 builder
.setInsertionPointToEnd(predicateBlock
);
538 unsigned ans
= cast
<UnsignedAnswer
>(child
.first
)->getValue();
540 case Predicates::OperandCountAtLeastQuestion
:
541 builder
.create
<pdl_interp::CheckOperandCountOp
>(
542 loc
, val
, ans
, /*compareAtLeast=*/true, childBlock
, defaultDest
);
544 case Predicates::ResultCountAtLeastQuestion
:
545 builder
.create
<pdl_interp::CheckResultCountOp
>(
546 loc
, val
, ans
, /*compareAtLeast=*/true, childBlock
, defaultDest
);
549 llvm_unreachable("Generating invalid AtLeast operation");
551 failureBlockStack
.back() = predicateBlock
;
553 Block
*firstPredicateBlock
= failureBlockStack
.pop_back_val();
554 currentBlock
->getOperations().splice(currentBlock
->end(),
555 firstPredicateBlock
->getOperations());
556 firstPredicateBlock
->erase();
560 // Otherwise, generate each of the children and generate an interpreter
562 llvm::MapVector
<Qualifier
*, Block
*> children
;
563 for (auto &it
: switchNode
->getChildren())
564 children
.insert({it
.first
, generateMatcher(*it
.second
, *region
)});
565 builder
.setInsertionPointToEnd(currentBlock
);
567 switch (question
->getKind()) {
568 case Predicates::OperandCountQuestion
:
569 return createSwitchOp
<pdl_interp::SwitchOperandCountOp
, UnsignedAnswer
,
570 int32_t>(val
, defaultDest
, builder
, children
);
571 case Predicates::ResultCountQuestion
:
572 return createSwitchOp
<pdl_interp::SwitchResultCountOp
, UnsignedAnswer
,
573 int32_t>(val
, defaultDest
, builder
, children
);
574 case Predicates::OperationNameQuestion
:
575 return createSwitchOp
<pdl_interp::SwitchOperationNameOp
,
576 OperationNameAnswer
>(val
, defaultDest
, builder
,
578 case Predicates::TypeQuestion
:
579 if (isa
<pdl::RangeType
>(val
.getType())) {
580 return createSwitchOp
<pdl_interp::SwitchTypesOp
, TypeAnswer
>(
581 val
, defaultDest
, builder
, children
);
583 return createSwitchOp
<pdl_interp::SwitchTypeOp
, TypeAnswer
>(
584 val
, defaultDest
, builder
, children
);
585 case Predicates::AttributeQuestion
:
586 return createSwitchOp
<pdl_interp::SwitchAttributeOp
, AttributeAnswer
>(
587 val
, defaultDest
, builder
, children
);
589 llvm_unreachable("Generating unknown switch predicate.");
593 void PatternLowering::generate(SuccessNode
*successNode
, Block
*¤tBlock
) {
594 pdl::PatternOp pattern
= successNode
->getPattern();
595 Value root
= successNode
->getRoot();
597 // Generate a rewriter for the pattern this success node represents, and track
598 // any values used from the match region.
599 SmallVector
<Position
*, 8> usedMatchValues
;
600 SymbolRefAttr rewriterFuncRef
= generateRewriter(pattern
, usedMatchValues
);
602 // Process any values used in the rewrite that are defined in the match.
603 std::vector
<Value
> mappedMatchValues
;
604 mappedMatchValues
.reserve(usedMatchValues
.size());
605 for (Position
*position
: usedMatchValues
)
606 mappedMatchValues
.push_back(getValueAt(currentBlock
, position
));
608 // Collect the set of operations generated by the rewriter.
609 SmallVector
<StringRef
, 4> generatedOps
;
611 pattern
.getRewriter().getBodyRegion().getOps
<pdl::OperationOp
>())
612 generatedOps
.push_back(*op
.getOpName());
613 ArrayAttr generatedOpsAttr
;
614 if (!generatedOps
.empty())
615 generatedOpsAttr
= builder
.getStrArrayAttr(generatedOps
);
617 // Grab the root kind if present.
618 StringAttr rootKindAttr
;
619 if (pdl::OperationOp rootOp
= root
.getDefiningOp
<pdl::OperationOp
>())
620 if (std::optional
<StringRef
> rootKind
= rootOp
.getOpName())
621 rootKindAttr
= builder
.getStringAttr(*rootKind
);
623 builder
.setInsertionPointToEnd(currentBlock
);
624 auto matchOp
= builder
.create
<pdl_interp::RecordMatchOp
>(
625 pattern
.getLoc(), mappedMatchValues
, locOps
.getArrayRef(),
626 rewriterFuncRef
, rootKindAttr
, generatedOpsAttr
, pattern
.getBenefitAttr(),
627 failureBlockStack
.back());
629 // Set the config of the lowered match to the parent pattern.
631 configMap
->try_emplace(matchOp
, configMap
->lookup(pattern
));
634 SymbolRefAttr
PatternLowering::generateRewriter(
635 pdl::PatternOp pattern
, SmallVectorImpl
<Position
*> &usedMatchValues
) {
636 builder
.setInsertionPointToEnd(rewriterModule
.getBody());
637 auto rewriterFunc
= builder
.create
<pdl_interp::FuncOp
>(
638 pattern
.getLoc(), "pdl_generated_rewriter",
639 builder
.getFunctionType(std::nullopt
, std::nullopt
));
640 rewriterSymbolTable
.insert(rewriterFunc
);
642 // Generate the rewriter function body.
643 builder
.setInsertionPointToEnd(&rewriterFunc
.front());
645 // Map an input operand of the pattern to a generated interpreter value.
646 DenseMap
<Value
, Value
> rewriteValues
;
647 auto mapRewriteValue
= [&](Value oldValue
) {
648 Value
&newValue
= rewriteValues
[oldValue
];
652 // Prefer materializing constants directly when possible.
653 Operation
*oldOp
= oldValue
.getDefiningOp();
654 if (pdl::AttributeOp attrOp
= dyn_cast
<pdl::AttributeOp
>(oldOp
)) {
655 if (Attribute value
= attrOp
.getValueAttr()) {
656 return newValue
= builder
.create
<pdl_interp::CreateAttributeOp
>(
657 attrOp
.getLoc(), value
);
659 } else if (pdl::TypeOp typeOp
= dyn_cast
<pdl::TypeOp
>(oldOp
)) {
660 if (TypeAttr type
= typeOp
.getConstantTypeAttr()) {
661 return newValue
= builder
.create
<pdl_interp::CreateTypeOp
>(
662 typeOp
.getLoc(), type
);
664 } else if (pdl::TypesOp typeOp
= dyn_cast
<pdl::TypesOp
>(oldOp
)) {
665 if (ArrayAttr type
= typeOp
.getConstantTypesAttr()) {
666 return newValue
= builder
.create
<pdl_interp::CreateTypesOp
>(
667 typeOp
.getLoc(), typeOp
.getType(), type
);
671 // Otherwise, add this as an input to the rewriter.
672 Position
*inputPos
= valueToPosition
.lookup(oldValue
);
673 assert(inputPos
&& "expected value to be a pattern input");
674 usedMatchValues
.push_back(inputPos
);
675 return newValue
= rewriterFunc
.front().addArgument(oldValue
.getType(),
679 // If this is a custom rewriter, simply dispatch to the registered rewrite
681 pdl::RewriteOp rewriter
= pattern
.getRewriter();
682 if (StringAttr rewriteName
= rewriter
.getNameAttr()) {
683 SmallVector
<Value
> args
;
684 if (rewriter
.getRoot())
685 args
.push_back(mapRewriteValue(rewriter
.getRoot()));
687 llvm::map_range(rewriter
.getExternalArgs(), mapRewriteValue
);
688 args
.append(mappedArgs
.begin(), mappedArgs
.end());
689 builder
.create
<pdl_interp::ApplyRewriteOp
>(
690 rewriter
.getLoc(), /*resultTypes=*/TypeRange(), rewriteName
, args
);
692 // Otherwise this is a dag rewriter defined using PDL operations.
693 for (Operation
&rewriteOp
: *rewriter
.getBody()) {
694 llvm::TypeSwitch
<Operation
*>(&rewriteOp
)
695 .Case
<pdl::ApplyNativeRewriteOp
, pdl::AttributeOp
, pdl::EraseOp
,
696 pdl::OperationOp
, pdl::RangeOp
, pdl::ReplaceOp
, pdl::ResultOp
,
697 pdl::ResultsOp
, pdl::TypeOp
, pdl::TypesOp
>([&](auto op
) {
698 this->generateRewriter(op
, rewriteValues
, mapRewriteValue
);
703 // Update the signature of the rewrite function.
704 rewriterFunc
.setType(builder
.getFunctionType(
705 llvm::to_vector
<8>(rewriterFunc
.front().getArgumentTypes()),
706 /*results=*/std::nullopt
));
708 builder
.create
<pdl_interp::FinalizeOp
>(rewriter
.getLoc());
709 return SymbolRefAttr::get(
710 builder
.getContext(),
711 pdl_interp::PDLInterpDialect::getRewriterModuleName(),
712 SymbolRefAttr::get(rewriterFunc
));
715 void PatternLowering::generateRewriter(
716 pdl::ApplyNativeRewriteOp rewriteOp
, DenseMap
<Value
, Value
> &rewriteValues
,
717 function_ref
<Value(Value
)> mapRewriteValue
) {
718 SmallVector
<Value
, 2> arguments
;
719 for (Value argument
: rewriteOp
.getArgs())
720 arguments
.push_back(mapRewriteValue(argument
));
721 auto interpOp
= builder
.create
<pdl_interp::ApplyRewriteOp
>(
722 rewriteOp
.getLoc(), rewriteOp
.getResultTypes(), rewriteOp
.getNameAttr(),
724 for (auto it
: llvm::zip(rewriteOp
.getResults(), interpOp
.getResults()))
725 rewriteValues
[std::get
<0>(it
)] = std::get
<1>(it
);
728 void PatternLowering::generateRewriter(
729 pdl::AttributeOp attrOp
, DenseMap
<Value
, Value
> &rewriteValues
,
730 function_ref
<Value(Value
)> mapRewriteValue
) {
731 Value newAttr
= builder
.create
<pdl_interp::CreateAttributeOp
>(
732 attrOp
.getLoc(), attrOp
.getValueAttr());
733 rewriteValues
[attrOp
] = newAttr
;
736 void PatternLowering::generateRewriter(
737 pdl::EraseOp eraseOp
, DenseMap
<Value
, Value
> &rewriteValues
,
738 function_ref
<Value(Value
)> mapRewriteValue
) {
739 builder
.create
<pdl_interp::EraseOp
>(eraseOp
.getLoc(),
740 mapRewriteValue(eraseOp
.getOpValue()));
743 void PatternLowering::generateRewriter(
744 pdl::OperationOp operationOp
, DenseMap
<Value
, Value
> &rewriteValues
,
745 function_ref
<Value(Value
)> mapRewriteValue
) {
746 SmallVector
<Value
, 4> operands
;
747 for (Value operand
: operationOp
.getOperandValues())
748 operands
.push_back(mapRewriteValue(operand
));
750 SmallVector
<Value
, 4> attributes
;
751 for (Value attr
: operationOp
.getAttributeValues())
752 attributes
.push_back(mapRewriteValue(attr
));
754 bool hasInferredResultTypes
= false;
755 SmallVector
<Value
, 2> types
;
756 generateOperationResultTypeRewriter(operationOp
, mapRewriteValue
, types
,
757 rewriteValues
, hasInferredResultTypes
);
759 // Create the new operation.
760 Location loc
= operationOp
.getLoc();
761 Value createdOp
= builder
.create
<pdl_interp::CreateOperationOp
>(
762 loc
, *operationOp
.getOpName(), types
, hasInferredResultTypes
, operands
,
763 attributes
, operationOp
.getAttributeValueNames());
764 rewriteValues
[operationOp
.getOp()] = createdOp
;
766 // Generate accesses for any results that have their types constrained.
767 // Handle the case where there is a single range representing all of the
769 OperandRange resultTys
= operationOp
.getTypeValues();
770 if (resultTys
.size() == 1 && isa
<pdl::RangeType
>(resultTys
[0].getType())) {
771 Value
&type
= rewriteValues
[resultTys
[0]];
773 auto results
= builder
.create
<pdl_interp::GetResultsOp
>(loc
, createdOp
);
774 type
= builder
.create
<pdl_interp::GetValueTypeOp
>(loc
, results
);
779 // Otherwise, populate the individual results.
780 bool seenVariableLength
= false;
781 Type valueTy
= builder
.getType
<pdl::ValueType
>();
782 Type valueRangeTy
= pdl::RangeType::get(valueTy
);
783 for (const auto &it
: llvm::enumerate(resultTys
)) {
784 Value
&type
= rewriteValues
[it
.value()];
787 bool isVariadic
= isa
<pdl::RangeType
>(it
.value().getType());
788 seenVariableLength
|= isVariadic
;
790 // After a variable length result has been seen, we need to use result
791 // groups because the exact index of the result is not statically known.
793 if (seenVariableLength
)
794 resultVal
= builder
.create
<pdl_interp::GetResultsOp
>(
795 loc
, isVariadic
? valueRangeTy
: valueTy
, createdOp
, it
.index());
797 resultVal
= builder
.create
<pdl_interp::GetResultOp
>(
798 loc
, valueTy
, createdOp
, it
.index());
799 type
= builder
.create
<pdl_interp::GetValueTypeOp
>(loc
, resultVal
);
803 void PatternLowering::generateRewriter(
804 pdl::RangeOp rangeOp
, DenseMap
<Value
, Value
> &rewriteValues
,
805 function_ref
<Value(Value
)> mapRewriteValue
) {
806 SmallVector
<Value
, 4> replOperands
;
807 for (Value operand
: rangeOp
.getArguments())
808 replOperands
.push_back(mapRewriteValue(operand
));
809 rewriteValues
[rangeOp
] = builder
.create
<pdl_interp::CreateRangeOp
>(
810 rangeOp
.getLoc(), rangeOp
.getType(), replOperands
);
813 void PatternLowering::generateRewriter(
814 pdl::ReplaceOp replaceOp
, DenseMap
<Value
, Value
> &rewriteValues
,
815 function_ref
<Value(Value
)> mapRewriteValue
) {
816 SmallVector
<Value
, 4> replOperands
;
818 // If the replacement was another operation, get its results. `pdl` allows
819 // for using an operation for simplicitly, but the interpreter isn't as
821 if (Value replOp
= replaceOp
.getReplOperation()) {
822 // Don't use replace if we know the replaced operation has no results.
823 auto opOp
= replaceOp
.getOpValue().getDefiningOp
<pdl::OperationOp
>();
824 if (!opOp
|| !opOp
.getTypeValues().empty()) {
825 replOperands
.push_back(builder
.create
<pdl_interp::GetResultsOp
>(
826 replOp
.getLoc(), mapRewriteValue(replOp
)));
829 for (Value operand
: replaceOp
.getReplValues())
830 replOperands
.push_back(mapRewriteValue(operand
));
833 // If there are no replacement values, just create an erase instead.
834 if (replOperands
.empty()) {
835 builder
.create
<pdl_interp::EraseOp
>(
836 replaceOp
.getLoc(), mapRewriteValue(replaceOp
.getOpValue()));
840 builder
.create
<pdl_interp::ReplaceOp
>(replaceOp
.getLoc(),
841 mapRewriteValue(replaceOp
.getOpValue()),
845 void PatternLowering::generateRewriter(
846 pdl::ResultOp resultOp
, DenseMap
<Value
, Value
> &rewriteValues
,
847 function_ref
<Value(Value
)> mapRewriteValue
) {
848 rewriteValues
[resultOp
] = builder
.create
<pdl_interp::GetResultOp
>(
849 resultOp
.getLoc(), builder
.getType
<pdl::ValueType
>(),
850 mapRewriteValue(resultOp
.getParent()), resultOp
.getIndex());
853 void PatternLowering::generateRewriter(
854 pdl::ResultsOp resultOp
, DenseMap
<Value
, Value
> &rewriteValues
,
855 function_ref
<Value(Value
)> mapRewriteValue
) {
856 rewriteValues
[resultOp
] = builder
.create
<pdl_interp::GetResultsOp
>(
857 resultOp
.getLoc(), resultOp
.getType(),
858 mapRewriteValue(resultOp
.getParent()), resultOp
.getIndex());
861 void PatternLowering::generateRewriter(
862 pdl::TypeOp typeOp
, DenseMap
<Value
, Value
> &rewriteValues
,
863 function_ref
<Value(Value
)> mapRewriteValue
) {
864 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
866 if (TypeAttr typeAttr
= typeOp
.getConstantTypeAttr()) {
867 rewriteValues
[typeOp
] =
868 builder
.create
<pdl_interp::CreateTypeOp
>(typeOp
.getLoc(), typeAttr
);
872 void PatternLowering::generateRewriter(
873 pdl::TypesOp typeOp
, DenseMap
<Value
, Value
> &rewriteValues
,
874 function_ref
<Value(Value
)> mapRewriteValue
) {
875 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
877 if (ArrayAttr typeAttr
= typeOp
.getConstantTypesAttr()) {
878 rewriteValues
[typeOp
] = builder
.create
<pdl_interp::CreateTypesOp
>(
879 typeOp
.getLoc(), typeOp
.getType(), typeAttr
);
883 void PatternLowering::generateOperationResultTypeRewriter(
884 pdl::OperationOp op
, function_ref
<Value(Value
)> mapRewriteValue
,
885 SmallVectorImpl
<Value
> &types
, DenseMap
<Value
, Value
> &rewriteValues
,
886 bool &hasInferredResultTypes
) {
887 Block
*rewriterBlock
= op
->getBlock();
889 // Try to handle resolution for each of the result types individually. This is
890 // preferred over type inferrence because it will allow for us to use existing
891 // types directly, as opposed to trying to rebuild the type list.
892 OperandRange resultTypeValues
= op
.getTypeValues();
893 auto tryResolveResultTypes
= [&] {
894 types
.reserve(resultTypeValues
.size());
895 for (const auto &it
: llvm::enumerate(resultTypeValues
)) {
896 Value resultType
= it
.value();
898 // Check for an already translated value.
899 if (Value existingRewriteValue
= rewriteValues
.lookup(resultType
)) {
900 types
.push_back(existingRewriteValue
);
904 // Check for an input from the matcher.
905 if (resultType
.getDefiningOp()->getBlock() != rewriterBlock
) {
906 types
.push_back(mapRewriteValue(resultType
));
910 // Otherwise, we couldn't infer the result types. Bail out here to see if
911 // we can infer the types for this operation from another way.
917 if (!resultTypeValues
.empty() && succeeded(tryResolveResultTypes()))
920 // Otherwise, check if the operation has type inference support itself.
921 if (op
.hasTypeInference()) {
922 hasInferredResultTypes
= true;
926 // Look for an operation that was replaced by `op`. The result types will be
927 // inferred from the results that were replaced.
928 for (OpOperand
&use
: op
.getOp().getUses()) {
929 // Check that the use corresponds to a ReplaceOp and that it is the
930 // replacement value, not the operation being replaced.
931 pdl::ReplaceOp replOpUser
= dyn_cast
<pdl::ReplaceOp
>(use
.getOwner());
932 if (!replOpUser
|| use
.getOperandNumber() == 0)
934 // Make sure the replaced operation was defined before this one. PDL
935 // rewrites only have single block regions, so if the op isn't in the
936 // rewriter block (i.e. the current block of the operation) we already know
937 // it dominates (i.e. it's in the matcher).
938 Value replOpVal
= replOpUser
.getOpValue();
939 Operation
*replacedOp
= replOpVal
.getDefiningOp();
940 if (replacedOp
->getBlock() == rewriterBlock
&&
941 !replacedOp
->isBeforeInBlock(op
))
944 Value replacedOpResults
= builder
.create
<pdl_interp::GetResultsOp
>(
945 replacedOp
->getLoc(), mapRewriteValue(replOpVal
));
946 types
.push_back(builder
.create
<pdl_interp::GetValueTypeOp
>(
947 replacedOp
->getLoc(), replacedOpResults
));
951 // If the types could not be inferred from any context and there weren't any
952 // explicit result types, assume the user actually meant for the operation to
954 if (resultTypeValues
.empty())
957 // The verifier asserts that the result types of each pdl.getOperation can be
958 // inferred. If we reach here, there is a bug either in the logic above or
959 // in the verifier for pdl.getOperation.
960 op
->emitOpError() << "unable to infer result type for operation";
961 llvm_unreachable("unable to infer result type for operation");
//===----------------------------------------------------------------------===//
// PDLToPDLInterpPass
//===----------------------------------------------------------------------===//
969 struct PDLToPDLInterpPass
970 : public impl::ConvertPDLToPDLInterpBase
<PDLToPDLInterpPass
> {
971 PDLToPDLInterpPass() = default;
972 PDLToPDLInterpPass(const PDLToPDLInterpPass
&rhs
) = default;
973 PDLToPDLInterpPass(DenseMap
<Operation
*, PDLPatternConfigSet
*> &configMap
)
974 : configMap(&configMap
) {}
975 void runOnOperation() final
;
977 /// A map containing the configuration for each pattern.
978 DenseMap
<Operation
*, PDLPatternConfigSet
*> *configMap
= nullptr;
982 /// Convert the given module containing PDL pattern operations into a PDL
983 /// Interpreter operations.
984 void PDLToPDLInterpPass::runOnOperation() {
985 ModuleOp module
= getOperation();
987 // Create the main matcher function This function contains all of the match
988 // related functionality from patterns in the module.
989 OpBuilder builder
= OpBuilder::atBlockBegin(module
.getBody());
990 auto matcherFunc
= builder
.create
<pdl_interp::FuncOp
>(
991 module
.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
992 builder
.getFunctionType(builder
.getType
<pdl::OperationType
>(),
993 /*results=*/std::nullopt
),
994 /*attrs=*/std::nullopt
);
996 // Create a nested module to hold the functions invoked for rewriting the IR
997 // after a successful match.
998 ModuleOp rewriterModule
= builder
.create
<ModuleOp
>(
999 module
.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
1001 // Generate the code for the patterns within the module.
1002 PatternLowering
generator(matcherFunc
, rewriterModule
, configMap
);
1003 generator
.lower(module
);
1005 // After generation, delete all of the pattern operations.
1006 for (pdl::PatternOp pattern
:
1007 llvm::make_early_inc_range(module
.getOps
<pdl::PatternOp
>())) {
1008 // Drop the now dead config mappings.
1010 configMap
->erase(pattern
);
1016 std::unique_ptr
<OperationPass
<ModuleOp
>> mlir::createPDLToPDLInterpPass() {
1017 return std::make_unique
<PDLToPDLInterpPass
>();
1019 std::unique_ptr
<OperationPass
<ModuleOp
>> mlir::createPDLToPDLInterpPass(
1020 DenseMap
<Operation
*, PDLPatternConfigSet
*> &configMap
) {
1021 return std::make_unique
<PDLToPDLInterpPass
>(configMap
);