[Workflow] Roll back some settings since they caused more issues
[llvm-project.git] / mlir / lib / Dialect / PDL / IR / PDL.cpp
blobd5f34679f06c60b4403a6f6ecfabe678e99cd2d5
1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Dialect/PDL/IR/PDL.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "llvm/ADT/DenseSet.h"
15 #include "llvm/ADT/TypeSwitch.h"
16 #include <optional>
18 using namespace mlir;
19 using namespace mlir::pdl;
21 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
23 //===----------------------------------------------------------------------===//
24 // PDLDialect
25 //===----------------------------------------------------------------------===//
27 void PDLDialect::initialize() {
28 addOperations<
29 #define GET_OP_LIST
30 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
31 >();
32 registerTypes();
35 //===----------------------------------------------------------------------===//
36 // PDL Operations
37 //===----------------------------------------------------------------------===//
39 /// Returns true if the given operation is used by a "binding" pdl operation.
40 static bool hasBindingUse(Operation *op) {
41 for (Operation *user : op->getUsers())
42 // A result by itself is not binding, it must also be bound.
43 if (!isa<ResultOp, ResultsOp>(user) || hasBindingUse(user))
44 return true;
45 return false;
48 /// Returns success if the given operation is not in the main matcher body or
49 /// is used by a "binding" operation. On failure, emits an error.
50 static LogicalResult verifyHasBindingUse(Operation *op) {
51 // If the parent is not a pattern, there is nothing to do.
52 if (!llvm::isa_and_nonnull<PatternOp>(op->getParentOp()))
53 return success();
54 if (hasBindingUse(op))
55 return success();
56 return op->emitOpError(
57 "expected a bindable user when defined in the matcher body of a "
58 "`pdl.pattern`");
61 /// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s)
62 /// connected to the given operation.
63 static void visit(Operation *op, DenseSet<Operation *> &visited) {
64 // If the parent is not a pattern, there is nothing to do.
65 if (!isa<PatternOp>(op->getParentOp()) || isa<RewriteOp>(op))
66 return;
68 // Ignore if already visited.
69 if (visited.contains(op))
70 return;
72 // Mark as visited.
73 visited.insert(op);
75 // Traverse the operands / parent.
76 TypeSwitch<Operation *>(op)
77 .Case<OperationOp>([&visited](auto operation) {
78 for (Value operand : operation.getOperandValues())
79 visit(operand.getDefiningOp(), visited);
81 .Case<ResultOp, ResultsOp>([&visited](auto result) {
82 visit(result.getParent().getDefiningOp(), visited);
83 });
85 // Traverse the users.
86 for (Operation *user : op->getUsers())
87 visit(user, visited);
90 //===----------------------------------------------------------------------===//
91 // pdl::ApplyNativeConstraintOp
92 //===----------------------------------------------------------------------===//
94 LogicalResult ApplyNativeConstraintOp::verify() {
95 if (getNumOperands() == 0)
96 return emitOpError("expected at least one argument");
97 return success();
100 //===----------------------------------------------------------------------===//
101 // pdl::ApplyNativeRewriteOp
102 //===----------------------------------------------------------------------===//
104 LogicalResult ApplyNativeRewriteOp::verify() {
105 if (getNumOperands() == 0 && getNumResults() == 0)
106 return emitOpError("expected at least one argument or result");
107 return success();
110 //===----------------------------------------------------------------------===//
111 // pdl::AttributeOp
112 //===----------------------------------------------------------------------===//
114 LogicalResult AttributeOp::verify() {
115 Value attrType = getValueType();
116 std::optional<Attribute> attrValue = getValue();
118 if (!attrValue) {
119 if (isa<RewriteOp>((*this)->getParentOp()))
120 return emitOpError(
121 "expected constant value when specified within a `pdl.rewrite`");
122 return verifyHasBindingUse(*this);
124 if (attrType)
125 return emitOpError("expected only one of [`type`, `value`] to be set");
126 return success();
129 //===----------------------------------------------------------------------===//
130 // pdl::OperandOp
131 //===----------------------------------------------------------------------===//
133 LogicalResult OperandOp::verify() { return verifyHasBindingUse(*this); }
135 //===----------------------------------------------------------------------===//
136 // pdl::OperandsOp
137 //===----------------------------------------------------------------------===//
139 LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); }
141 //===----------------------------------------------------------------------===//
142 // pdl::OperationOp
143 //===----------------------------------------------------------------------===//
145 static ParseResult parseOperationOpAttributes(
146 OpAsmParser &p,
147 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
148 ArrayAttr &attrNamesAttr) {
149 Builder &builder = p.getBuilder();
150 SmallVector<Attribute, 4> attrNames;
151 if (succeeded(p.parseOptionalLBrace())) {
152 auto parseOperands = [&]() {
153 StringAttr nameAttr;
154 OpAsmParser::UnresolvedOperand operand;
155 if (p.parseAttribute(nameAttr) || p.parseEqual() ||
156 p.parseOperand(operand))
157 return failure();
158 attrNames.push_back(nameAttr);
159 attrOperands.push_back(operand);
160 return success();
162 if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
163 return failure();
165 attrNamesAttr = builder.getArrayAttr(attrNames);
166 return success();
169 static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
170 OperandRange attrArgs,
171 ArrayAttr attrNames) {
172 if (attrNames.empty())
173 return;
174 p << " {";
175 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
176 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
177 p << '}';
180 /// Verifies that the result types of this operation, defined within a
181 /// `pdl.rewrite`, can be inferred.
182 static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
183 OperandRange resultTypes) {
184 // Functor that returns if the given use can be used to infer a type.
185 Block *rewriterBlock = op->getBlock();
186 auto canInferTypeFromUse = [&](OpOperand &use) {
187 // If the use is within a ReplaceOp and isn't the operation being replaced
188 // (i.e. is not the first operand of the replacement), we can infer a type.
189 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
190 if (!replOpUser || use.getOperandNumber() == 0)
191 return false;
192 // Make sure the replaced operation was defined before this one.
193 Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
194 return replacedOp->getBlock() != rewriterBlock ||
195 replacedOp->isBeforeInBlock(op);
198 // Check to see if the uses of the operation itself can be used to infer
199 // types.
200 if (llvm::any_of(op.getOp().getUses(), canInferTypeFromUse))
201 return success();
203 // Handle the case where the operation has no explicit result types.
204 if (resultTypes.empty()) {
205 // If we don't know the concrete operation, don't attempt any verification.
206 // We can't make assumptions if we don't know the concrete operation.
207 std::optional<StringRef> rawOpName = op.getOpName();
208 if (!rawOpName)
209 return success();
210 std::optional<RegisteredOperationName> opName =
211 RegisteredOperationName::lookup(*rawOpName, op.getContext());
212 if (!opName)
213 return success();
215 // If no explicit result types were provided, check to see if the operation
216 // expected at least one result. This doesn't cover all cases, but this
217 // should cover many cases in which the user intended to infer the results
218 // of an operation, but it isn't actually possible.
219 bool expectedAtLeastOneResult =
220 !opName->hasTrait<OpTrait::ZeroResults>() &&
221 !opName->hasTrait<OpTrait::VariadicResults>();
222 if (expectedAtLeastOneResult) {
223 return op
224 .emitOpError("must have inferable or constrained result types when "
225 "nested within `pdl.rewrite`")
226 .attachNote()
227 .append("operation is created in a non-inferrable context, but '",
228 *opName, "' does not implement InferTypeOpInterface");
230 return success();
233 // Otherwise, make sure each of the types can be inferred.
234 for (const auto &it : llvm::enumerate(resultTypes)) {
235 Operation *resultTypeOp = it.value().getDefiningOp();
236 assert(resultTypeOp && "expected valid result type operation");
238 // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be
239 // usable.
240 if (isa<ApplyNativeRewriteOp>(resultTypeOp))
241 continue;
243 // If the type operation was defined in the matcher and constrains an
244 // operand or the result of an input operation, it can be used.
245 auto constrainsInput = [rewriterBlock](Operation *user) {
246 return user->getBlock() != rewriterBlock &&
247 isa<OperandOp, OperandsOp, OperationOp>(user);
249 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
250 if (typeOp.getConstantType() ||
251 llvm::any_of(typeOp->getUsers(), constrainsInput))
252 continue;
253 } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
254 if (typeOp.getConstantTypes() ||
255 llvm::any_of(typeOp->getUsers(), constrainsInput))
256 continue;
259 return op
260 .emitOpError("must have inferable or constrained result types when "
261 "nested within `pdl.rewrite`")
262 .attachNote()
263 .append("result type #", it.index(), " was not constrained");
265 return success();
268 LogicalResult OperationOp::verify() {
269 bool isWithinRewrite = isa_and_nonnull<RewriteOp>((*this)->getParentOp());
270 if (isWithinRewrite && !getOpName())
271 return emitOpError("must have an operation name when nested within "
272 "a `pdl.rewrite`");
273 ArrayAttr attributeNames = getAttributeValueNamesAttr();
274 auto attributeValues = getAttributeValues();
275 if (attributeNames.size() != attributeValues.size()) {
276 return emitOpError()
277 << "expected the same number of attribute values and attribute "
278 "names, got "
279 << attributeNames.size() << " names and " << attributeValues.size()
280 << " values";
283 // If the operation is within a rewrite body and doesn't have type inference,
284 // ensure that the result types can be resolved.
285 if (isWithinRewrite && !mightHaveTypeInference()) {
286 if (failed(verifyResultTypesAreInferrable(*this, getTypeValues())))
287 return failure();
290 return verifyHasBindingUse(*this);
293 bool OperationOp::hasTypeInference() {
294 if (std::optional<StringRef> rawOpName = getOpName()) {
295 OperationName opName(*rawOpName, getContext());
296 return opName.hasInterface<InferTypeOpInterface>();
298 return false;
301 bool OperationOp::mightHaveTypeInference() {
302 if (std::optional<StringRef> rawOpName = getOpName()) {
303 OperationName opName(*rawOpName, getContext());
304 return opName.mightHaveInterface<InferTypeOpInterface>();
306 return false;
309 //===----------------------------------------------------------------------===//
310 // pdl::PatternOp
311 //===----------------------------------------------------------------------===//
313 LogicalResult PatternOp::verifyRegions() {
314 Region &body = getBodyRegion();
315 Operation *term = body.front().getTerminator();
316 auto rewriteOp = dyn_cast<RewriteOp>(term);
317 if (!rewriteOp) {
318 return emitOpError("expected body to terminate with `pdl.rewrite`")
319 .attachNote(term->getLoc())
320 .append("see terminator defined here");
323 // Check that all values defined in the top-level pattern belong to the PDL
324 // dialect.
325 WalkResult result = body.walk([&](Operation *op) -> WalkResult {
326 if (!isa_and_nonnull<PDLDialect>(op->getDialect())) {
327 emitOpError("expected only `pdl` operations within the pattern body")
328 .attachNote(op->getLoc())
329 .append("see non-`pdl` operation defined here");
330 return WalkResult::interrupt();
332 return WalkResult::advance();
334 if (result.wasInterrupted())
335 return failure();
337 // Check that there is at least one operation.
338 if (body.front().getOps<OperationOp>().empty())
339 return emitOpError("the pattern must contain at least one `pdl.operation`");
341 // Determine if the operations within the pdl.pattern form a connected
342 // component. This is determined by starting the search from the first
343 // operand/result/operation and visiting their users / parents / operands.
344 // We limit our attention to operations that have a user in pdl.rewrite,
345 // those that do not will be detected via other means (expected bindable
346 // user).
347 bool first = true;
348 DenseSet<Operation *> visited;
349 for (Operation &op : body.front()) {
350 // The following are the operations forming the connected component.
351 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op))
352 continue;
354 // Determine if the operation has a user in `pdl.rewrite`.
355 bool hasUserInRewrite = false;
356 for (Operation *user : op.getUsers()) {
357 Region *region = user->getParentRegion();
358 if (isa<RewriteOp>(user) ||
359 (region && isa<RewriteOp>(region->getParentOp()))) {
360 hasUserInRewrite = true;
361 break;
365 // If the operation does not have a user in `pdl.rewrite`, ignore it.
366 if (!hasUserInRewrite)
367 continue;
369 if (first) {
370 // For the first operation, invoke visit.
371 visit(&op, visited);
372 first = false;
373 } else if (!visited.count(&op)) {
374 // For the subsequent operations, check if already visited.
375 return emitOpError("the operations must form a connected component")
376 .attachNote(op.getLoc())
377 .append("see a disconnected value / operation here");
381 return success();
384 void PatternOp::build(OpBuilder &builder, OperationState &state,
385 std::optional<uint16_t> benefit,
386 std::optional<StringRef> name) {
387 build(builder, state, builder.getI16IntegerAttr(benefit ? *benefit : 0),
388 name ? builder.getStringAttr(*name) : StringAttr());
389 state.regions[0]->emplaceBlock();
392 /// Returns the rewrite operation of this pattern.
393 RewriteOp PatternOp::getRewriter() {
394 return cast<RewriteOp>(getBodyRegion().front().getTerminator());
397 /// The default dialect is `pdl`.
398 StringRef PatternOp::getDefaultDialect() {
399 return PDLDialect::getDialectNamespace();
402 //===----------------------------------------------------------------------===//
403 // pdl::RangeOp
404 //===----------------------------------------------------------------------===//
406 static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
407 Type &resultType) {
408 // If arguments were provided, infer the result type from the argument list.
409 if (!argumentTypes.empty()) {
410 resultType = RangeType::get(getRangeElementTypeOrSelf(argumentTypes[0]));
411 return success();
413 // Otherwise, parse the type as a trailing type.
414 return p.parseColonType(resultType);
417 static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes,
418 Type resultType) {
419 if (argumentTypes.empty())
420 p << ": " << resultType;
423 LogicalResult RangeOp::verify() {
424 Type elementType = getType().getElementType();
425 for (Type operandType : getOperandTypes()) {
426 Type operandElementType = getRangeElementTypeOrSelf(operandType);
427 if (operandElementType != elementType) {
428 return emitOpError("expected operand to have element type ")
429 << elementType << ", but got " << operandElementType;
432 return success();
435 //===----------------------------------------------------------------------===//
436 // pdl::ReplaceOp
437 //===----------------------------------------------------------------------===//
439 LogicalResult ReplaceOp::verify() {
440 if (getReplOperation() && !getReplValues().empty())
441 return emitOpError() << "expected no replacement values to be provided"
442 " when the replacement operation is present";
443 return success();
446 //===----------------------------------------------------------------------===//
447 // pdl::ResultsOp
448 //===----------------------------------------------------------------------===//
450 static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index,
451 Type &resultType) {
452 if (!index) {
453 resultType = RangeType::get(p.getBuilder().getType<ValueType>());
454 return success();
456 if (p.parseArrow() || p.parseType(resultType))
457 return failure();
458 return success();
461 static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
462 IntegerAttr index, Type resultType) {
463 if (index)
464 p << " -> " << resultType;
467 LogicalResult ResultsOp::verify() {
468 if (!getIndex() && llvm::isa<pdl::ValueType>(getType())) {
469 return emitOpError() << "expected `pdl.range<value>` result type when "
470 "no index is specified, but got: "
471 << getType();
473 return success();
476 //===----------------------------------------------------------------------===//
477 // pdl::RewriteOp
478 //===----------------------------------------------------------------------===//
480 LogicalResult RewriteOp::verifyRegions() {
481 Region &rewriteRegion = getBodyRegion();
483 // Handle the case where the rewrite is external.
484 if (getName()) {
485 if (!rewriteRegion.empty()) {
486 return emitOpError()
487 << "expected rewrite region to be empty when rewrite is external";
489 return success();
492 // Otherwise, check that the rewrite region only contains a single block.
493 if (rewriteRegion.empty()) {
494 return emitOpError() << "expected rewrite region to be non-empty if "
495 "external name is not specified";
498 // Check that no additional arguments were provided.
499 if (!getExternalArgs().empty()) {
500 return emitOpError() << "expected no external arguments when the "
501 "rewrite is specified inline";
504 return success();
507 /// The default dialect is `pdl`.
508 StringRef RewriteOp::getDefaultDialect() {
509 return PDLDialect::getDialectNamespace();
512 //===----------------------------------------------------------------------===//
513 // pdl::TypeOp
514 //===----------------------------------------------------------------------===//
516 LogicalResult TypeOp::verify() {
517 if (!getConstantTypeAttr())
518 return verifyHasBindingUse(*this);
519 return success();
522 //===----------------------------------------------------------------------===//
523 // pdl::TypesOp
524 //===----------------------------------------------------------------------===//
526 LogicalResult TypesOp::verify() {
527 if (!getConstantTypesAttr())
528 return verifyHasBindingUse(*this);
529 return success();
532 //===----------------------------------------------------------------------===//
533 // TableGen'd op method definitions
534 //===----------------------------------------------------------------------===//
536 #define GET_OP_CLASSES
537 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"