1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/PDL/IR/PDL.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "llvm/ADT/DenseSet.h"
15 #include "llvm/ADT/TypeSwitch.h"
19 using namespace mlir::pdl
;
21 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
//===----------------------------------------------------------------------===//
// PDLDialect
//===----------------------------------------------------------------------===//
27 void PDLDialect::initialize() {
30 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
//===----------------------------------------------------------------------===//
// PDL Operations
//===----------------------------------------------------------------------===//
39 /// Returns true if the given operation is used by a "binding" pdl operation.
40 static bool hasBindingUse(Operation
*op
) {
41 for (Operation
*user
: op
->getUsers())
42 // A result by itself is not binding, it must also be bound.
43 if (!isa
<ResultOp
, ResultsOp
>(user
) || hasBindingUse(user
))
48 /// Returns success if the given operation is not in the main matcher body or
49 /// is used by a "binding" operation. On failure, emits an error.
50 static LogicalResult
verifyHasBindingUse(Operation
*op
) {
51 // If the parent is not a pattern, there is nothing to do.
52 if (!llvm::isa_and_nonnull
<PatternOp
>(op
->getParentOp()))
54 if (hasBindingUse(op
))
56 return op
->emitOpError(
57 "expected a bindable user when defined in the matcher body of a "
61 /// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s)
62 /// connected to the given operation.
63 static void visit(Operation
*op
, DenseSet
<Operation
*> &visited
) {
64 // If the parent is not a pattern, there is nothing to do.
65 if (!isa
<PatternOp
>(op
->getParentOp()) || isa
<RewriteOp
>(op
))
68 // Ignore if already visited.
69 if (visited
.contains(op
))
75 // Traverse the operands / parent.
76 TypeSwitch
<Operation
*>(op
)
77 .Case
<OperationOp
>([&visited
](auto operation
) {
78 for (Value operand
: operation
.getOperandValues())
79 visit(operand
.getDefiningOp(), visited
);
81 .Case
<ResultOp
, ResultsOp
>([&visited
](auto result
) {
82 visit(result
.getParent().getDefiningOp(), visited
);
85 // Traverse the users.
86 for (Operation
*user
: op
->getUsers())
90 //===----------------------------------------------------------------------===//
91 // pdl::ApplyNativeConstraintOp
92 //===----------------------------------------------------------------------===//
94 LogicalResult
ApplyNativeConstraintOp::verify() {
95 if (getNumOperands() == 0)
96 return emitOpError("expected at least one argument");
100 //===----------------------------------------------------------------------===//
101 // pdl::ApplyNativeRewriteOp
102 //===----------------------------------------------------------------------===//
104 LogicalResult
ApplyNativeRewriteOp::verify() {
105 if (getNumOperands() == 0 && getNumResults() == 0)
106 return emitOpError("expected at least one argument or result");
//===----------------------------------------------------------------------===//
// pdl::AttributeOp
//===----------------------------------------------------------------------===//
114 LogicalResult
AttributeOp::verify() {
115 Value attrType
= getValueType();
116 std::optional
<Attribute
> attrValue
= getValue();
119 if (isa
<RewriteOp
>((*this)->getParentOp()))
121 "expected constant value when specified within a `pdl.rewrite`");
122 return verifyHasBindingUse(*this);
125 return emitOpError("expected only one of [`type`, `value`] to be set");
//===----------------------------------------------------------------------===//
// pdl::OperandOp
//===----------------------------------------------------------------------===//
133 LogicalResult
OperandOp::verify() { return verifyHasBindingUse(*this); }
//===----------------------------------------------------------------------===//
// pdl::OperandsOp
//===----------------------------------------------------------------------===//
139 LogicalResult
OperandsOp::verify() { return verifyHasBindingUse(*this); }
//===----------------------------------------------------------------------===//
// pdl::OperationOp
//===----------------------------------------------------------------------===//
145 static ParseResult
parseOperationOpAttributes(
147 SmallVectorImpl
<OpAsmParser::UnresolvedOperand
> &attrOperands
,
148 ArrayAttr
&attrNamesAttr
) {
149 Builder
&builder
= p
.getBuilder();
150 SmallVector
<Attribute
, 4> attrNames
;
151 if (succeeded(p
.parseOptionalLBrace())) {
152 auto parseOperands
= [&]() {
154 OpAsmParser::UnresolvedOperand operand
;
155 if (p
.parseAttribute(nameAttr
) || p
.parseEqual() ||
156 p
.parseOperand(operand
))
158 attrNames
.push_back(nameAttr
);
159 attrOperands
.push_back(operand
);
162 if (p
.parseCommaSeparatedList(parseOperands
) || p
.parseRBrace())
165 attrNamesAttr
= builder
.getArrayAttr(attrNames
);
169 static void printOperationOpAttributes(OpAsmPrinter
&p
, OperationOp op
,
170 OperandRange attrArgs
,
171 ArrayAttr attrNames
) {
172 if (attrNames
.empty())
175 interleaveComma(llvm::seq
<int>(0, attrNames
.size()), p
,
176 [&](int i
) { p
<< attrNames
[i
] << " = " << attrArgs
[i
]; });
180 /// Verifies that the result types of this operation, defined within a
181 /// `pdl.rewrite`, can be inferred.
182 static LogicalResult
verifyResultTypesAreInferrable(OperationOp op
,
183 OperandRange resultTypes
) {
184 // Functor that returns if the given use can be used to infer a type.
185 Block
*rewriterBlock
= op
->getBlock();
186 auto canInferTypeFromUse
= [&](OpOperand
&use
) {
187 // If the use is within a ReplaceOp and isn't the operation being replaced
188 // (i.e. is not the first operand of the replacement), we can infer a type.
189 ReplaceOp replOpUser
= dyn_cast
<ReplaceOp
>(use
.getOwner());
190 if (!replOpUser
|| use
.getOperandNumber() == 0)
192 // Make sure the replaced operation was defined before this one.
193 Operation
*replacedOp
= replOpUser
.getOpValue().getDefiningOp();
194 return replacedOp
->getBlock() != rewriterBlock
||
195 replacedOp
->isBeforeInBlock(op
);
198 // Check to see if the uses of the operation itself can be used to infer
200 if (llvm::any_of(op
.getOp().getUses(), canInferTypeFromUse
))
203 // Handle the case where the operation has no explicit result types.
204 if (resultTypes
.empty()) {
205 // If we don't know the concrete operation, don't attempt any verification.
206 // We can't make assumptions if we don't know the concrete operation.
207 std::optional
<StringRef
> rawOpName
= op
.getOpName();
210 std::optional
<RegisteredOperationName
> opName
=
211 RegisteredOperationName::lookup(*rawOpName
, op
.getContext());
215 // If no explicit result types were provided, check to see if the operation
216 // expected at least one result. This doesn't cover all cases, but this
217 // should cover many cases in which the user intended to infer the results
218 // of an operation, but it isn't actually possible.
219 bool expectedAtLeastOneResult
=
220 !opName
->hasTrait
<OpTrait::ZeroResults
>() &&
221 !opName
->hasTrait
<OpTrait::VariadicResults
>();
222 if (expectedAtLeastOneResult
) {
224 .emitOpError("must have inferable or constrained result types when "
225 "nested within `pdl.rewrite`")
227 .append("operation is created in a non-inferrable context, but '",
228 *opName
, "' does not implement InferTypeOpInterface");
233 // Otherwise, make sure each of the types can be inferred.
234 for (const auto &it
: llvm::enumerate(resultTypes
)) {
235 Operation
*resultTypeOp
= it
.value().getDefiningOp();
236 assert(resultTypeOp
&& "expected valid result type operation");
238 // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be
240 if (isa
<ApplyNativeRewriteOp
>(resultTypeOp
))
243 // If the type operation was defined in the matcher and constrains an
244 // operand or the result of an input operation, it can be used.
245 auto constrainsInput
= [rewriterBlock
](Operation
*user
) {
246 return user
->getBlock() != rewriterBlock
&&
247 isa
<OperandOp
, OperandsOp
, OperationOp
>(user
);
249 if (TypeOp typeOp
= dyn_cast
<TypeOp
>(resultTypeOp
)) {
250 if (typeOp
.getConstantType() ||
251 llvm::any_of(typeOp
->getUsers(), constrainsInput
))
253 } else if (TypesOp typeOp
= dyn_cast
<TypesOp
>(resultTypeOp
)) {
254 if (typeOp
.getConstantTypes() ||
255 llvm::any_of(typeOp
->getUsers(), constrainsInput
))
260 .emitOpError("must have inferable or constrained result types when "
261 "nested within `pdl.rewrite`")
263 .append("result type #", it
.index(), " was not constrained");
268 LogicalResult
OperationOp::verify() {
269 bool isWithinRewrite
= isa_and_nonnull
<RewriteOp
>((*this)->getParentOp());
270 if (isWithinRewrite
&& !getOpName())
271 return emitOpError("must have an operation name when nested within "
273 ArrayAttr attributeNames
= getAttributeValueNamesAttr();
274 auto attributeValues
= getAttributeValues();
275 if (attributeNames
.size() != attributeValues
.size()) {
277 << "expected the same number of attribute values and attribute "
279 << attributeNames
.size() << " names and " << attributeValues
.size()
283 // If the operation is within a rewrite body and doesn't have type inference,
284 // ensure that the result types can be resolved.
285 if (isWithinRewrite
&& !mightHaveTypeInference()) {
286 if (failed(verifyResultTypesAreInferrable(*this, getTypeValues())))
290 return verifyHasBindingUse(*this);
293 bool OperationOp::hasTypeInference() {
294 if (std::optional
<StringRef
> rawOpName
= getOpName()) {
295 OperationName
opName(*rawOpName
, getContext());
296 return opName
.hasInterface
<InferTypeOpInterface
>();
301 bool OperationOp::mightHaveTypeInference() {
302 if (std::optional
<StringRef
> rawOpName
= getOpName()) {
303 OperationName
opName(*rawOpName
, getContext());
304 return opName
.mightHaveInterface
<InferTypeOpInterface
>();
//===----------------------------------------------------------------------===//
// pdl::PatternOp
//===----------------------------------------------------------------------===//
313 LogicalResult
PatternOp::verifyRegions() {
314 Region
&body
= getBodyRegion();
315 Operation
*term
= body
.front().getTerminator();
316 auto rewriteOp
= dyn_cast
<RewriteOp
>(term
);
318 return emitOpError("expected body to terminate with `pdl.rewrite`")
319 .attachNote(term
->getLoc())
320 .append("see terminator defined here");
323 // Check that all values defined in the top-level pattern belong to the PDL
325 WalkResult result
= body
.walk([&](Operation
*op
) -> WalkResult
{
326 if (!isa_and_nonnull
<PDLDialect
>(op
->getDialect())) {
327 emitOpError("expected only `pdl` operations within the pattern body")
328 .attachNote(op
->getLoc())
329 .append("see non-`pdl` operation defined here");
330 return WalkResult::interrupt();
332 return WalkResult::advance();
334 if (result
.wasInterrupted())
337 // Check that there is at least one operation.
338 if (body
.front().getOps
<OperationOp
>().empty())
339 return emitOpError("the pattern must contain at least one `pdl.operation`");
341 // Determine if the operations within the pdl.pattern form a connected
342 // component. This is determined by starting the search from the first
343 // operand/result/operation and visiting their users / parents / operands.
344 // We limit our attention to operations that have a user in pdl.rewrite,
345 // those that do not will be detected via other means (expected bindable
348 DenseSet
<Operation
*> visited
;
349 for (Operation
&op
: body
.front()) {
350 // The following are the operations forming the connected component.
351 if (!isa
<OperandOp
, OperandsOp
, ResultOp
, ResultsOp
, OperationOp
>(op
))
354 // Determine if the operation has a user in `pdl.rewrite`.
355 bool hasUserInRewrite
= false;
356 for (Operation
*user
: op
.getUsers()) {
357 Region
*region
= user
->getParentRegion();
358 if (isa
<RewriteOp
>(user
) ||
359 (region
&& isa
<RewriteOp
>(region
->getParentOp()))) {
360 hasUserInRewrite
= true;
365 // If the operation does not have a user in `pdl.rewrite`, ignore it.
366 if (!hasUserInRewrite
)
370 // For the first operation, invoke visit.
373 } else if (!visited
.count(&op
)) {
374 // For the subsequent operations, check if already visited.
375 return emitOpError("the operations must form a connected component")
376 .attachNote(op
.getLoc())
377 .append("see a disconnected value / operation here");
384 void PatternOp::build(OpBuilder
&builder
, OperationState
&state
,
385 std::optional
<uint16_t> benefit
,
386 std::optional
<StringRef
> name
) {
387 build(builder
, state
, builder
.getI16IntegerAttr(benefit
? *benefit
: 0),
388 name
? builder
.getStringAttr(*name
) : StringAttr());
389 state
.regions
[0]->emplaceBlock();
392 /// Returns the rewrite operation of this pattern.
393 RewriteOp
PatternOp::getRewriter() {
394 return cast
<RewriteOp
>(getBodyRegion().front().getTerminator());
397 /// The default dialect is `pdl`.
398 StringRef
PatternOp::getDefaultDialect() {
399 return PDLDialect::getDialectNamespace();
//===----------------------------------------------------------------------===//
// pdl::RangeOp
//===----------------------------------------------------------------------===//
406 static ParseResult
parseRangeType(OpAsmParser
&p
, TypeRange argumentTypes
,
408 // If arguments were provided, infer the result type from the argument list.
409 if (!argumentTypes
.empty()) {
410 resultType
= RangeType::get(getRangeElementTypeOrSelf(argumentTypes
[0]));
413 // Otherwise, parse the type as a trailing type.
414 return p
.parseColonType(resultType
);
417 static void printRangeType(OpAsmPrinter
&p
, RangeOp op
, TypeRange argumentTypes
,
419 if (argumentTypes
.empty())
420 p
<< ": " << resultType
;
423 LogicalResult
RangeOp::verify() {
424 Type elementType
= getType().getElementType();
425 for (Type operandType
: getOperandTypes()) {
426 Type operandElementType
= getRangeElementTypeOrSelf(operandType
);
427 if (operandElementType
!= elementType
) {
428 return emitOpError("expected operand to have element type ")
429 << elementType
<< ", but got " << operandElementType
;
//===----------------------------------------------------------------------===//
// pdl::ReplaceOp
//===----------------------------------------------------------------------===//
439 LogicalResult
ReplaceOp::verify() {
440 if (getReplOperation() && !getReplValues().empty())
441 return emitOpError() << "expected no replacement values to be provided"
442 " when the replacement operation is present";
//===----------------------------------------------------------------------===//
// pdl::ResultsOp
//===----------------------------------------------------------------------===//
450 static ParseResult
parseResultsValueType(OpAsmParser
&p
, IntegerAttr index
,
453 resultType
= RangeType::get(p
.getBuilder().getType
<ValueType
>());
456 if (p
.parseArrow() || p
.parseType(resultType
))
461 static void printResultsValueType(OpAsmPrinter
&p
, ResultsOp op
,
462 IntegerAttr index
, Type resultType
) {
464 p
<< " -> " << resultType
;
467 LogicalResult
ResultsOp::verify() {
468 if (!getIndex() && llvm::isa
<pdl::ValueType
>(getType())) {
469 return emitOpError() << "expected `pdl.range<value>` result type when "
470 "no index is specified, but got: "
//===----------------------------------------------------------------------===//
// pdl::RewriteOp
//===----------------------------------------------------------------------===//
480 LogicalResult
RewriteOp::verifyRegions() {
481 Region
&rewriteRegion
= getBodyRegion();
483 // Handle the case where the rewrite is external.
485 if (!rewriteRegion
.empty()) {
487 << "expected rewrite region to be empty when rewrite is external";
492 // Otherwise, check that the rewrite region only contains a single block.
493 if (rewriteRegion
.empty()) {
494 return emitOpError() << "expected rewrite region to be non-empty if "
495 "external name is not specified";
498 // Check that no additional arguments were provided.
499 if (!getExternalArgs().empty()) {
500 return emitOpError() << "expected no external arguments when the "
501 "rewrite is specified inline";
507 /// The default dialect is `pdl`.
508 StringRef
RewriteOp::getDefaultDialect() {
509 return PDLDialect::getDialectNamespace();
//===----------------------------------------------------------------------===//
// pdl::TypeOp
//===----------------------------------------------------------------------===//
516 LogicalResult
TypeOp::verify() {
517 if (!getConstantTypeAttr())
518 return verifyHasBindingUse(*this);
//===----------------------------------------------------------------------===//
// pdl::TypesOp
//===----------------------------------------------------------------------===//
526 LogicalResult
TypesOp::verify() {
527 if (!getConstantTypesAttr())
528 return verifyHasBindingUse(*this);
532 //===----------------------------------------------------------------------===//
533 // TableGen'd op method definitions
534 //===----------------------------------------------------------------------===//
536 #define GET_OP_CLASSES
537 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"