Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Conversion / PDLToPDLInterp / PDLToPDLInterp.cpp
blob b00cd0dee3ae8091cca29a319a1d2623e27ae3d8
1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
28 using namespace mlir;
29 using namespace mlir::pdl_to_pdl_interp;
31 //===----------------------------------------------------------------------===//
32 // PatternLowering
33 //===----------------------------------------------------------------------===//
35 namespace {
36 /// This class generators operations within the PDL Interpreter dialect from a
37 /// given module containing PDL pattern operations.
38 struct PatternLowering {
39 public:
40 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
41 DenseMap<Operation *, PDLPatternConfigSet *> *configMap);
43 /// Generate code for matching and rewriting based on the pattern operations
44 /// within the module.
45 void lower(ModuleOp module);
47 private:
48 using ValueMap = llvm::ScopedHashTable<Position *, Value>;
49 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
51 /// Generate interpreter operations for the tree rooted at the given matcher
52 /// node, in the specified region.
53 Block *generateMatcher(MatcherNode &node, Region &region,
54 Block *block = nullptr);
56 /// Get or create an access to the provided positional value in the current
57 /// block. This operation may mutate the provided block pointer if nested
58 /// regions (i.e., pdl_interp.iterate) are required.
59 Value getValueAt(Block *&currentBlock, Position *pos);
61 /// Create the interpreter predicate operations. This operation may mutate the
62 /// provided current block pointer if nested regions (iterates) are required.
63 void generate(BoolNode *boolNode, Block *&currentBlock, Value val);
65 /// Create the interpreter switch / predicate operations, with several case
66 /// destinations. This operation never mutates the provided current block
67 /// pointer, because the switch operation does not need Values beyond `val`.
68 void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
70 /// Create the interpreter operations to record a successful pattern match
71 /// using the contained root operation. This operation may mutate the current
72 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
73 void generate(SuccessNode *successNode, Block *&currentBlock);
75 /// Generate a rewriter function for the given pattern operation, and returns
76 /// a reference to that function.
77 SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
78 SmallVectorImpl<Position *> &usedMatchValues);
80 /// Generate the rewriter code for the given operation.
81 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
82 DenseMap<Value, Value> &rewriteValues,
83 function_ref<Value(Value)> mapRewriteValue);
84 void generateRewriter(pdl::AttributeOp attrOp,
85 DenseMap<Value, Value> &rewriteValues,
86 function_ref<Value(Value)> mapRewriteValue);
87 void generateRewriter(pdl::EraseOp eraseOp,
88 DenseMap<Value, Value> &rewriteValues,
89 function_ref<Value(Value)> mapRewriteValue);
90 void generateRewriter(pdl::OperationOp operationOp,
91 DenseMap<Value, Value> &rewriteValues,
92 function_ref<Value(Value)> mapRewriteValue);
93 void generateRewriter(pdl::RangeOp rangeOp,
94 DenseMap<Value, Value> &rewriteValues,
95 function_ref<Value(Value)> mapRewriteValue);
96 void generateRewriter(pdl::ReplaceOp replaceOp,
97 DenseMap<Value, Value> &rewriteValues,
98 function_ref<Value(Value)> mapRewriteValue);
99 void generateRewriter(pdl::ResultOp resultOp,
100 DenseMap<Value, Value> &rewriteValues,
101 function_ref<Value(Value)> mapRewriteValue);
102 void generateRewriter(pdl::ResultsOp resultOp,
103 DenseMap<Value, Value> &rewriteValues,
104 function_ref<Value(Value)> mapRewriteValue);
105 void generateRewriter(pdl::TypeOp typeOp,
106 DenseMap<Value, Value> &rewriteValues,
107 function_ref<Value(Value)> mapRewriteValue);
108 void generateRewriter(pdl::TypesOp typeOp,
109 DenseMap<Value, Value> &rewriteValues,
110 function_ref<Value(Value)> mapRewriteValue);
112 /// Generate the values used for resolving the result types of an operation
113 /// created within a dag rewriter region. If the result types of the operation
114 /// should be inferred, `hasInferredResultTypes` is set to true.
115 void generateOperationResultTypeRewriter(
116 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
117 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
118 bool &hasInferredResultTypes);
120 /// A builder to use when generating interpreter operations.
121 OpBuilder builder;
123 /// The matcher function used for all match related logic within PDL patterns.
124 pdl_interp::FuncOp matcherFunc;
126 /// The rewriter module containing the all rewrite related logic within PDL
127 /// patterns.
128 ModuleOp rewriterModule;
130 /// The symbol table of the rewriter module used for insertion.
131 SymbolTable rewriterSymbolTable;
133 /// A scoped map connecting a position with the corresponding interpreter
134 /// value.
135 ValueMap values;
137 /// A stack of blocks used as the failure destination for matcher nodes that
138 /// don't have an explicit failure path.
139 SmallVector<Block *, 8> failureBlockStack;
141 /// A mapping between values defined in a pattern match, and the corresponding
142 /// positional value.
143 DenseMap<Value, Position *> valueToPosition;
145 /// The set of operation values whose location will be used for newly
146 /// generated operations.
147 SetVector<Value> locOps;
149 /// A mapping between pattern operations and the corresponding configuration
150 /// set.
151 DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
153 /// A mapping from a constraint question to the ApplyConstraintOp
154 /// that implements it.
155 DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
157 } // namespace
159 PatternLowering::PatternLowering(
160 pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
161 DenseMap<Operation *, PDLPatternConfigSet *> *configMap)
162 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
163 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
164 configMap(configMap) {}
166 void PatternLowering::lower(ModuleOp module) {
167 PredicateUniquer predicateUniquer;
168 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
170 // Define top-level scope for the arguments to the matcher function.
171 ValueMapScope topLevelValueScope(values);
173 // Insert the root operation, i.e. argument to the matcher, at the root
174 // position.
175 Block *matcherEntryBlock = &matcherFunc.front();
176 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
178 // Generate a root matcher node from the provided PDL module.
179 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
180 module, predicateBuilder, valueToPosition);
181 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
182 assert(failureBlockStack.empty() && "failed to empty the stack");
184 // After generation, merged the first matched block into the entry.
185 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
186 firstMatcherBlock->getOperations());
187 firstMatcherBlock->erase();
190 Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
191 Block *block) {
192 // Push a new scope for the values used by this matcher.
193 if (!block)
194 block = &region.emplaceBlock();
195 ValueMapScope scope(values);
197 // If this is the return node, simply insert the corresponding interpreter
198 // finalize.
199 if (isa<ExitNode>(node)) {
200 builder.setInsertionPointToEnd(block);
201 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
202 return block;
205 // Get the next block in the match sequence.
206 // This is intentionally executed first, before we get the value for the
207 // position associated with the node, so that we preserve an "there exist"
208 // semantics: if getting a value requires an upward traversal (going from a
209 // value to its consumers), we want to perform the check on all the consumers
210 // before we pass control to the failure node.
211 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
212 Block *failureBlock;
213 if (failureNode) {
214 failureBlock = generateMatcher(*failureNode, region);
215 failureBlockStack.push_back(failureBlock);
216 } else {
217 assert(!failureBlockStack.empty() && "expected valid failure block");
218 failureBlock = failureBlockStack.back();
221 // If this node contains a position, get the corresponding value for this
222 // block.
223 Block *currentBlock = block;
224 Position *position = node.getPosition();
225 Value val = position ? getValueAt(currentBlock, position) : Value();
227 // If this value corresponds to an operation, record that we are going to use
228 // its location as part of a fused location.
229 bool isOperationValue = val && isa<pdl::OperationType>(val.getType());
230 if (isOperationValue)
231 locOps.insert(val);
233 // Dispatch to the correct method based on derived node type.
234 TypeSwitch<MatcherNode *>(&node)
235 .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
236 this->generate(derivedNode, currentBlock, val);
238 .Case([&](SuccessNode *successNode) {
239 generate(successNode, currentBlock);
242 // Pop all the failure blocks that were inserted due to nesting of
243 // pdl_interp.iterate.
244 while (failureBlockStack.back() != failureBlock) {
245 failureBlockStack.pop_back();
246 assert(!failureBlockStack.empty() && "unable to locate failure block");
249 // Pop the new failure block.
250 if (failureNode)
251 failureBlockStack.pop_back();
253 if (isOperationValue)
254 locOps.remove(val);
256 return block;
259 Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
260 if (Value val = values.lookup(pos))
261 return val;
263 // Get the value for the parent position.
264 Value parentVal;
265 if (Position *parent = pos->getParent())
266 parentVal = getValueAt(currentBlock, parent);
268 // TODO: Use a location from the position.
269 Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
270 builder.setInsertionPointToEnd(currentBlock);
271 Value value;
272 switch (pos->getKind()) {
273 case Predicates::OperationPos: {
274 auto *operationPos = cast<OperationPosition>(pos);
275 if (operationPos->isOperandDefiningOp())
276 // Standard (downward) traversal which directly follows the defining op.
277 value = builder.create<pdl_interp::GetDefiningOpOp>(
278 loc, builder.getType<pdl::OperationType>(), parentVal);
279 else
280 // A passthrough operation position.
281 value = parentVal;
282 break;
284 case Predicates::UsersPos: {
285 auto *usersPos = cast<UsersPosition>(pos);
287 // The first operation retrieves the representative value of a range.
288 // This applies only when the parent is a range of values and we were
289 // requested to use a representative value (e.g., upward traversal).
290 if (isa<pdl::RangeType>(parentVal.getType()) &&
291 usersPos->useRepresentative())
292 value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
293 else
294 value = parentVal;
296 // The second operation retrieves the users.
297 value = builder.create<pdl_interp::GetUsersOp>(loc, value);
298 break;
300 case Predicates::ForEachPos: {
301 assert(!failureBlockStack.empty() && "expected valid failure block");
302 auto foreach = builder.create<pdl_interp::ForEachOp>(
303 loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
304 value = foreach.getLoopVariable();
306 // Create the continuation block.
307 Block *continueBlock = builder.createBlock(&foreach.getRegion());
308 builder.create<pdl_interp::ContinueOp>(loc);
309 failureBlockStack.push_back(continueBlock);
311 currentBlock = &foreach.getRegion().front();
312 break;
314 case Predicates::OperandPos: {
315 auto *operandPos = cast<OperandPosition>(pos);
316 value = builder.create<pdl_interp::GetOperandOp>(
317 loc, builder.getType<pdl::ValueType>(), parentVal,
318 operandPos->getOperandNumber());
319 break;
321 case Predicates::OperandGroupPos: {
322 auto *operandPos = cast<OperandGroupPosition>(pos);
323 Type valueTy = builder.getType<pdl::ValueType>();
324 value = builder.create<pdl_interp::GetOperandsOp>(
325 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
326 parentVal, operandPos->getOperandGroupNumber());
327 break;
329 case Predicates::AttributePos: {
330 auto *attrPos = cast<AttributePosition>(pos);
331 value = builder.create<pdl_interp::GetAttributeOp>(
332 loc, builder.getType<pdl::AttributeType>(), parentVal,
333 attrPos->getName().strref());
334 break;
336 case Predicates::TypePos: {
337 if (isa<pdl::AttributeType>(parentVal.getType()))
338 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
339 else
340 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
341 break;
343 case Predicates::ResultPos: {
344 auto *resPos = cast<ResultPosition>(pos);
345 value = builder.create<pdl_interp::GetResultOp>(
346 loc, builder.getType<pdl::ValueType>(), parentVal,
347 resPos->getResultNumber());
348 break;
350 case Predicates::ResultGroupPos: {
351 auto *resPos = cast<ResultGroupPosition>(pos);
352 Type valueTy = builder.getType<pdl::ValueType>();
353 value = builder.create<pdl_interp::GetResultsOp>(
354 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
355 parentVal, resPos->getResultGroupNumber());
356 break;
358 case Predicates::AttributeLiteralPos: {
359 auto *attrPos = cast<AttributeLiteralPosition>(pos);
360 value =
361 builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
362 break;
364 case Predicates::TypeLiteralPos: {
365 auto *typePos = cast<TypeLiteralPosition>(pos);
366 Attribute rawTypeAttr = typePos->getValue();
367 if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
368 value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
369 else
370 value = builder.create<pdl_interp::CreateTypesOp>(
371 loc, cast<ArrayAttr>(rawTypeAttr));
372 break;
374 case Predicates::ConstraintResultPos: {
375 // Due to the order of traversal, the ApplyConstraintOp has already been
376 // created and we can find it in constraintOpMap.
377 auto *constrResPos = cast<ConstraintPosition>(pos);
378 auto i = constraintOpMap.find(constrResPos->getQuestion());
379 assert(i != constraintOpMap.end());
380 value = i->second->getResult(constrResPos->getIndex());
381 break;
383 default:
384 llvm_unreachable("Generating unknown Position getter");
385 break;
388 values.insert(pos, value);
389 return value;
392 void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
393 Value val) {
394 Location loc = val.getLoc();
395 Qualifier *question = boolNode->getQuestion();
396 Qualifier *answer = boolNode->getAnswer();
397 Region *region = currentBlock->getParent();
399 // Execute the getValue queries first, so that we create success
400 // matcher in the correct (possibly nested) region.
401 SmallVector<Value> args;
402 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
403 args = {getValueAt(currentBlock, equalToQuestion->getValue())};
404 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
405 for (Position *position : cstQuestion->getArgs())
406 args.push_back(getValueAt(currentBlock, position));
409 // Generate a new block as success successor and get the failure successor.
410 Block *success = &region->emplaceBlock();
411 Block *failure = failureBlockStack.back();
413 // Create the predicate.
414 builder.setInsertionPointToEnd(currentBlock);
415 Predicates::Kind kind = question->getKind();
416 switch (kind) {
417 case Predicates::IsNotNullQuestion:
418 builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
419 break;
420 case Predicates::OperationNameQuestion: {
421 auto *opNameAnswer = cast<OperationNameAnswer>(answer);
422 builder.create<pdl_interp::CheckOperationNameOp>(
423 loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
424 break;
426 case Predicates::TypeQuestion: {
427 auto *ans = cast<TypeAnswer>(answer);
428 if (isa<pdl::RangeType>(val.getType()))
429 builder.create<pdl_interp::CheckTypesOp>(
430 loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
431 else
432 builder.create<pdl_interp::CheckTypeOp>(
433 loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
434 break;
436 case Predicates::AttributeQuestion: {
437 auto *ans = cast<AttributeAnswer>(answer);
438 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
439 success, failure);
440 break;
442 case Predicates::OperandCountAtLeastQuestion:
443 case Predicates::OperandCountQuestion:
444 builder.create<pdl_interp::CheckOperandCountOp>(
445 loc, val, cast<UnsignedAnswer>(answer)->getValue(),
446 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
447 success, failure);
448 break;
449 case Predicates::ResultCountAtLeastQuestion:
450 case Predicates::ResultCountQuestion:
451 builder.create<pdl_interp::CheckResultCountOp>(
452 loc, val, cast<UnsignedAnswer>(answer)->getValue(),
453 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
454 success, failure);
455 break;
456 case Predicates::EqualToQuestion: {
457 bool trueAnswer = isa<TrueAnswer>(answer);
458 builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
459 trueAnswer ? success : failure,
460 trueAnswer ? failure : success);
461 break;
463 case Predicates::ConstraintQuestion: {
464 auto *cstQuestion = cast<ConstraintQuestion>(question);
465 auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
466 loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
467 cstQuestion->getIsNegated(), success, failure);
469 constraintOpMap.insert({cstQuestion, applyConstraintOp});
470 break;
472 default:
473 llvm_unreachable("Generating unknown Predicate operation");
476 // Generate the matcher in the current (potentially nested) region.
477 // This might use the results of the current predicate.
478 generateMatcher(*boolNode->getSuccessNode(), *region, success);
481 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
482 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
483 llvm::MapVector<Qualifier *, Block *> &dests) {
484 std::vector<ValT> values;
485 std::vector<Block *> blocks;
486 values.reserve(dests.size());
487 blocks.reserve(dests.size());
488 for (const auto &it : dests) {
489 blocks.push_back(it.second);
490 values.push_back(cast<PredT>(it.first)->getValue());
492 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
495 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
496 Value val) {
497 Qualifier *question = switchNode->getQuestion();
498 Region *region = currentBlock->getParent();
499 Block *defaultDest = failureBlockStack.back();
501 // If the switch question is not an exact answer, i.e. for the `at_least`
502 // cases, we generate a special block sequence.
503 Predicates::Kind kind = question->getKind();
504 if (kind == Predicates::OperandCountAtLeastQuestion ||
505 kind == Predicates::ResultCountAtLeastQuestion) {
506 // Order the children such that the cases are in reverse numerical order.
507 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
508 llvm::seq<unsigned>(0, switchNode->getChildren().size()));
509 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
510 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
511 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
514 // Build the destination for each child using the next highest child as a
515 // a failure destination. This essentially creates the following control
516 // flow:
518 // if (operand_count < 1)
519 // goto failure
520 // if (child1.match())
521 // ...
523 // if (operand_count < 2)
524 // goto failure
525 // if (child2.match())
526 // ...
528 // failure:
529 // ...
531 failureBlockStack.push_back(defaultDest);
532 Location loc = val.getLoc();
533 for (unsigned idx : sortedChildren) {
534 auto &child = switchNode->getChild(idx);
535 Block *childBlock = generateMatcher(*child.second, *region);
536 Block *predicateBlock = builder.createBlock(childBlock);
537 builder.setInsertionPointToEnd(predicateBlock);
538 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
539 switch (kind) {
540 case Predicates::OperandCountAtLeastQuestion:
541 builder.create<pdl_interp::CheckOperandCountOp>(
542 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
543 break;
544 case Predicates::ResultCountAtLeastQuestion:
545 builder.create<pdl_interp::CheckResultCountOp>(
546 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
547 break;
548 default:
549 llvm_unreachable("Generating invalid AtLeast operation");
551 failureBlockStack.back() = predicateBlock;
553 Block *firstPredicateBlock = failureBlockStack.pop_back_val();
554 currentBlock->getOperations().splice(currentBlock->end(),
555 firstPredicateBlock->getOperations());
556 firstPredicateBlock->erase();
557 return;
560 // Otherwise, generate each of the children and generate an interpreter
561 // switch.
562 llvm::MapVector<Qualifier *, Block *> children;
563 for (auto &it : switchNode->getChildren())
564 children.insert({it.first, generateMatcher(*it.second, *region)});
565 builder.setInsertionPointToEnd(currentBlock);
567 switch (question->getKind()) {
568 case Predicates::OperandCountQuestion:
569 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
570 int32_t>(val, defaultDest, builder, children);
571 case Predicates::ResultCountQuestion:
572 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
573 int32_t>(val, defaultDest, builder, children);
574 case Predicates::OperationNameQuestion:
575 return createSwitchOp<pdl_interp::SwitchOperationNameOp,
576 OperationNameAnswer>(val, defaultDest, builder,
577 children);
578 case Predicates::TypeQuestion:
579 if (isa<pdl::RangeType>(val.getType())) {
580 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
581 val, defaultDest, builder, children);
583 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
584 val, defaultDest, builder, children);
585 case Predicates::AttributeQuestion:
586 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
587 val, defaultDest, builder, children);
588 default:
589 llvm_unreachable("Generating unknown switch predicate.");
593 void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
594 pdl::PatternOp pattern = successNode->getPattern();
595 Value root = successNode->getRoot();
597 // Generate a rewriter for the pattern this success node represents, and track
598 // any values used from the match region.
599 SmallVector<Position *, 8> usedMatchValues;
600 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
602 // Process any values used in the rewrite that are defined in the match.
603 std::vector<Value> mappedMatchValues;
604 mappedMatchValues.reserve(usedMatchValues.size());
605 for (Position *position : usedMatchValues)
606 mappedMatchValues.push_back(getValueAt(currentBlock, position));
608 // Collect the set of operations generated by the rewriter.
609 SmallVector<StringRef, 4> generatedOps;
610 for (auto op :
611 pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
612 generatedOps.push_back(*op.getOpName());
613 ArrayAttr generatedOpsAttr;
614 if (!generatedOps.empty())
615 generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
617 // Grab the root kind if present.
618 StringAttr rootKindAttr;
619 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
620 if (std::optional<StringRef> rootKind = rootOp.getOpName())
621 rootKindAttr = builder.getStringAttr(*rootKind);
623 builder.setInsertionPointToEnd(currentBlock);
624 auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
625 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
626 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
627 failureBlockStack.back());
629 // Set the config of the lowered match to the parent pattern.
630 if (configMap)
631 configMap->try_emplace(matchOp, configMap->lookup(pattern));
634 SymbolRefAttr PatternLowering::generateRewriter(
635 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
636 builder.setInsertionPointToEnd(rewriterModule.getBody());
637 auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
638 pattern.getLoc(), "pdl_generated_rewriter",
639 builder.getFunctionType(std::nullopt, std::nullopt));
640 rewriterSymbolTable.insert(rewriterFunc);
642 // Generate the rewriter function body.
643 builder.setInsertionPointToEnd(&rewriterFunc.front());
645 // Map an input operand of the pattern to a generated interpreter value.
646 DenseMap<Value, Value> rewriteValues;
647 auto mapRewriteValue = [&](Value oldValue) {
648 Value &newValue = rewriteValues[oldValue];
649 if (newValue)
650 return newValue;
652 // Prefer materializing constants directly when possible.
653 Operation *oldOp = oldValue.getDefiningOp();
654 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
655 if (Attribute value = attrOp.getValueAttr()) {
656 return newValue = builder.create<pdl_interp::CreateAttributeOp>(
657 attrOp.getLoc(), value);
659 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
660 if (TypeAttr type = typeOp.getConstantTypeAttr()) {
661 return newValue = builder.create<pdl_interp::CreateTypeOp>(
662 typeOp.getLoc(), type);
664 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
665 if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
666 return newValue = builder.create<pdl_interp::CreateTypesOp>(
667 typeOp.getLoc(), typeOp.getType(), type);
671 // Otherwise, add this as an input to the rewriter.
672 Position *inputPos = valueToPosition.lookup(oldValue);
673 assert(inputPos && "expected value to be a pattern input");
674 usedMatchValues.push_back(inputPos);
675 return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
676 oldValue.getLoc());
679 // If this is a custom rewriter, simply dispatch to the registered rewrite
680 // method.
681 pdl::RewriteOp rewriter = pattern.getRewriter();
682 if (StringAttr rewriteName = rewriter.getNameAttr()) {
683 SmallVector<Value> args;
684 if (rewriter.getRoot())
685 args.push_back(mapRewriteValue(rewriter.getRoot()));
686 auto mappedArgs =
687 llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
688 args.append(mappedArgs.begin(), mappedArgs.end());
689 builder.create<pdl_interp::ApplyRewriteOp>(
690 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
691 } else {
692 // Otherwise this is a dag rewriter defined using PDL operations.
693 for (Operation &rewriteOp : *rewriter.getBody()) {
694 llvm::TypeSwitch<Operation *>(&rewriteOp)
695 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
696 pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
697 pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
698 this->generateRewriter(op, rewriteValues, mapRewriteValue);
703 // Update the signature of the rewrite function.
704 rewriterFunc.setType(builder.getFunctionType(
705 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
706 /*results=*/std::nullopt));
708 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
709 return SymbolRefAttr::get(
710 builder.getContext(),
711 pdl_interp::PDLInterpDialect::getRewriterModuleName(),
712 SymbolRefAttr::get(rewriterFunc));
715 void PatternLowering::generateRewriter(
716 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
717 function_ref<Value(Value)> mapRewriteValue) {
718 SmallVector<Value, 2> arguments;
719 for (Value argument : rewriteOp.getArgs())
720 arguments.push_back(mapRewriteValue(argument));
721 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
722 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
723 arguments);
724 for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
725 rewriteValues[std::get<0>(it)] = std::get<1>(it);
728 void PatternLowering::generateRewriter(
729 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
730 function_ref<Value(Value)> mapRewriteValue) {
731 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
732 attrOp.getLoc(), attrOp.getValueAttr());
733 rewriteValues[attrOp] = newAttr;
736 void PatternLowering::generateRewriter(
737 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
738 function_ref<Value(Value)> mapRewriteValue) {
739 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
740 mapRewriteValue(eraseOp.getOpValue()));
743 void PatternLowering::generateRewriter(
744 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
745 function_ref<Value(Value)> mapRewriteValue) {
746 SmallVector<Value, 4> operands;
747 for (Value operand : operationOp.getOperandValues())
748 operands.push_back(mapRewriteValue(operand));
750 SmallVector<Value, 4> attributes;
751 for (Value attr : operationOp.getAttributeValues())
752 attributes.push_back(mapRewriteValue(attr));
754 bool hasInferredResultTypes = false;
755 SmallVector<Value, 2> types;
756 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
757 rewriteValues, hasInferredResultTypes);
759 // Create the new operation.
760 Location loc = operationOp.getLoc();
761 Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
762 loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
763 attributes, operationOp.getAttributeValueNames());
764 rewriteValues[operationOp.getOp()] = createdOp;
766 // Generate accesses for any results that have their types constrained.
767 // Handle the case where there is a single range representing all of the
768 // result types.
769 OperandRange resultTys = operationOp.getTypeValues();
770 if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
771 Value &type = rewriteValues[resultTys[0]];
772 if (!type) {
773 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
774 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
776 return;
779 // Otherwise, populate the individual results.
780 bool seenVariableLength = false;
781 Type valueTy = builder.getType<pdl::ValueType>();
782 Type valueRangeTy = pdl::RangeType::get(valueTy);
783 for (const auto &it : llvm::enumerate(resultTys)) {
784 Value &type = rewriteValues[it.value()];
785 if (type)
786 continue;
787 bool isVariadic = isa<pdl::RangeType>(it.value().getType());
788 seenVariableLength |= isVariadic;
790 // After a variable length result has been seen, we need to use result
791 // groups because the exact index of the result is not statically known.
792 Value resultVal;
793 if (seenVariableLength)
794 resultVal = builder.create<pdl_interp::GetResultsOp>(
795 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
796 else
797 resultVal = builder.create<pdl_interp::GetResultOp>(
798 loc, valueTy, createdOp, it.index());
799 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
803 void PatternLowering::generateRewriter(
804 pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
805 function_ref<Value(Value)> mapRewriteValue) {
806 SmallVector<Value, 4> replOperands;
807 for (Value operand : rangeOp.getArguments())
808 replOperands.push_back(mapRewriteValue(operand));
809 rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
810 rangeOp.getLoc(), rangeOp.getType(), replOperands);
813 void PatternLowering::generateRewriter(
814 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
815 function_ref<Value(Value)> mapRewriteValue) {
816 SmallVector<Value, 4> replOperands;
818 // If the replacement was another operation, get its results. `pdl` allows
819 // for using an operation for simplicitly, but the interpreter isn't as
820 // user facing.
821 if (Value replOp = replaceOp.getReplOperation()) {
822 // Don't use replace if we know the replaced operation has no results.
823 auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
824 if (!opOp || !opOp.getTypeValues().empty()) {
825 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
826 replOp.getLoc(), mapRewriteValue(replOp)));
828 } else {
829 for (Value operand : replaceOp.getReplValues())
830 replOperands.push_back(mapRewriteValue(operand));
833 // If there are no replacement values, just create an erase instead.
834 if (replOperands.empty()) {
835 builder.create<pdl_interp::EraseOp>(
836 replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
837 return;
840 builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
841 mapRewriteValue(replaceOp.getOpValue()),
842 replOperands);
845 void PatternLowering::generateRewriter(
846 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
847 function_ref<Value(Value)> mapRewriteValue) {
848 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
849 resultOp.getLoc(), builder.getType<pdl::ValueType>(),
850 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
853 void PatternLowering::generateRewriter(
854 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
855 function_ref<Value(Value)> mapRewriteValue) {
856 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
857 resultOp.getLoc(), resultOp.getType(),
858 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
861 void PatternLowering::generateRewriter(
862 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
863 function_ref<Value(Value)> mapRewriteValue) {
864 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
865 // type.
866 if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
867 rewriteValues[typeOp] =
868 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
872 void PatternLowering::generateRewriter(
873 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
874 function_ref<Value(Value)> mapRewriteValue) {
875 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
876 // type.
877 if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
878 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
879 typeOp.getLoc(), typeOp.getType(), typeAttr);
883 void PatternLowering::generateOperationResultTypeRewriter(
884 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
885 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
886 bool &hasInferredResultTypes) {
887 Block *rewriterBlock = op->getBlock();
889 // Try to handle resolution for each of the result types individually. This is
890 // preferred over type inferrence because it will allow for us to use existing
891 // types directly, as opposed to trying to rebuild the type list.
892 OperandRange resultTypeValues = op.getTypeValues();
893 auto tryResolveResultTypes = [&] {
894 types.reserve(resultTypeValues.size());
895 for (const auto &it : llvm::enumerate(resultTypeValues)) {
896 Value resultType = it.value();
898 // Check for an already translated value.
899 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
900 types.push_back(existingRewriteValue);
901 continue;
904 // Check for an input from the matcher.
905 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
906 types.push_back(mapRewriteValue(resultType));
907 continue;
910 // Otherwise, we couldn't infer the result types. Bail out here to see if
911 // we can infer the types for this operation from another way.
912 types.clear();
913 return failure();
915 return success();
917 if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
918 return;
920 // Otherwise, check if the operation has type inference support itself.
921 if (op.hasTypeInference()) {
922 hasInferredResultTypes = true;
923 return;
926 // Look for an operation that was replaced by `op`. The result types will be
927 // inferred from the results that were replaced.
928 for (OpOperand &use : op.getOp().getUses()) {
929 // Check that the use corresponds to a ReplaceOp and that it is the
930 // replacement value, not the operation being replaced.
931 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
932 if (!replOpUser || use.getOperandNumber() == 0)
933 continue;
934 // Make sure the replaced operation was defined before this one. PDL
935 // rewrites only have single block regions, so if the op isn't in the
936 // rewriter block (i.e. the current block of the operation) we already know
937 // it dominates (i.e. it's in the matcher).
938 Value replOpVal = replOpUser.getOpValue();
939 Operation *replacedOp = replOpVal.getDefiningOp();
940 if (replacedOp->getBlock() == rewriterBlock &&
941 !replacedOp->isBeforeInBlock(op))
942 continue;
944 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
945 replacedOp->getLoc(), mapRewriteValue(replOpVal));
946 types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
947 replacedOp->getLoc(), replacedOpResults));
948 return;
951 // If the types could not be inferred from any context and there weren't any
952 // explicit result types, assume the user actually meant for the operation to
953 // have no results.
954 if (resultTypeValues.empty())
955 return;
957 // The verifier asserts that the result types of each pdl.getOperation can be
958 // inferred. If we reach here, there is a bug either in the logic above or
959 // in the verifier for pdl.getOperation.
960 op->emitOpError() << "unable to infer result type for operation";
961 llvm_unreachable("unable to infer result type for operation");
964 //===----------------------------------------------------------------------===//
965 // Conversion Pass
966 //===----------------------------------------------------------------------===//
968 namespace {
969 struct PDLToPDLInterpPass
970 : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
971 PDLToPDLInterpPass() = default;
972 PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
973 PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
974 : configMap(&configMap) {}
975 void runOnOperation() final;
977 /// A map containing the configuration for each pattern.
978 DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
980 } // namespace
982 /// Convert the given module containing PDL pattern operations into a PDL
983 /// Interpreter operations.
984 void PDLToPDLInterpPass::runOnOperation() {
985 ModuleOp module = getOperation();
987 // Create the main matcher function This function contains all of the match
988 // related functionality from patterns in the module.
989 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
990 auto matcherFunc = builder.create<pdl_interp::FuncOp>(
991 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
992 builder.getFunctionType(builder.getType<pdl::OperationType>(),
993 /*results=*/std::nullopt),
994 /*attrs=*/std::nullopt);
996 // Create a nested module to hold the functions invoked for rewriting the IR
997 // after a successful match.
998 ModuleOp rewriterModule = builder.create<ModuleOp>(
999 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
1001 // Generate the code for the patterns within the module.
1002 PatternLowering generator(matcherFunc, rewriterModule, configMap);
1003 generator.lower(module);
1005 // After generation, delete all of the pattern operations.
1006 for (pdl::PatternOp pattern :
1007 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
1008 // Drop the now dead config mappings.
1009 if (configMap)
1010 configMap->erase(pattern);
1012 pattern.erase();
1016 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
1017 return std::make_unique<PDLToPDLInterpPass>();
1019 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass(
1020 DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
1021 return std::make_unique<PDLToPDLInterpPass>(configMap);