[Infra] Fix version-check workflow (#100090)
[llvm-project.git] / mlir / tools / mlir-tblgen / RewriterGen.cpp
blob2c79ba2cd6353eef6d801459190242e05d16fe51
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Argument.h"
15 #include "mlir/TableGen/Attribute.h"
16 #include "mlir/TableGen/CodeGenHelpers.h"
17 #include "mlir/TableGen/Format.h"
18 #include "mlir/TableGen/GenInfo.h"
19 #include "mlir/TableGen/Operator.h"
20 #include "mlir/TableGen/Pattern.h"
21 #include "mlir/TableGen/Predicate.h"
22 #include "mlir/TableGen/Property.h"
23 #include "mlir/TableGen/Type.h"
24 #include "llvm/ADT/FunctionExtras.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/FormatAdapters.h"
31 #include "llvm/Support/PrettyStackTrace.h"
32 #include "llvm/Support/Signals.h"
33 #include "llvm/TableGen/Error.h"
34 #include "llvm/TableGen/Main.h"
35 #include "llvm/TableGen/Record.h"
36 #include "llvm/TableGen/TableGenBackend.h"
38 using namespace mlir;
39 using namespace mlir::tblgen;
41 using llvm::formatv;
42 using llvm::Record;
43 using llvm::RecordKeeper;
45 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
47 namespace llvm {
48 template <>
49 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
50 static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
51 raw_ostream &os, StringRef style) {
52 os << v.first << ":" << v.second;
55 } // namespace llvm
57 //===----------------------------------------------------------------------===//
58 // PatternEmitter
59 //===----------------------------------------------------------------------===//
namespace {

class StaticMatcherHelper;

// Emits the C++ for one DRR pattern record: the match logic, the rewrite
// logic, and (when StaticMatcherHelper asks for it) static matcher functions
// for sub-DAGs that are shared between patterns.
class PatternEmitter {
public:
  PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
                 StaticMatcherHelper &helper);

  // Emits the mlir::RewritePattern struct named `rewriteName`.
  void emit(StringRef rewriteName);

  // Emits a static C++ function named `funcName` that matches the DAG `tree`;
  // the symbols bound inside `tree` are passed to it by reference so that the
  // captured values are visible to the caller.
  void emitStaticMatcher(DagNode tree, std::string funcName);

private:
  // Emits the code for matching ops.
  void emitMatchLogic(DagNode tree, StringRef opName);

  // Emits the code for rewriting ops.
  void emitRewriteLogic();

  //===--------------------------------------------------------------------===//
  // Match utilities
  //===--------------------------------------------------------------------===//

  // Emits C++ statements for matching the DAG structure.
  void emitMatch(DagNode tree, StringRef name, int depth);

  // Emits a C++ call to the static DAG matcher generated for `tree`.
  void emitStaticMatchCall(DagNode tree, StringRef name);

  // Emits a C++ call to a static type/attribute constraint function.
  void emitStaticVerifierCall(StringRef funcName, StringRef opName,
                              StringRef arg, StringRef failureStr);

  // Emits C++ statements for matching using a native code call.
  void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);

  // Emits C++ statements for matching the op constrained by the given DAG
  // `tree`. The matched op is bound to a local named `castedOp<depth>`.
  void emitOpMatch(DagNode tree, StringRef opName, int depth);

  // Emits C++ statements for matching one argument of the given DAG `tree` as
  // an operand. `operandName` is the C++ expression for the operand range
  // being matched, `operandIndex` indexes the op's argument list (it is passed
  // to `Operator::getArg` and to the symbol lookup), `operandMatcher` is the
  // DRR constraint on the operand and `argName` the symbol it is bound to.
  void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
                        int operandIndex, DagLeaf operandMatcher,
                        StringRef argName, int argIndex,
                        std::optional<int> variadicSubIndex);

  // Emits C++ statements for matching the operands which can be matched in
  // either order.
  void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
                              StringRef opName, int argIndex, int &operandIndex,
                              int depth);

  // Emits C++ statements for matching a variadic operand.
  void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree,
                                StringRef opName, int argIndex,
                                int &operandIndex, int depth);

  // Emits C++ statements for matching the `argIndex`-th argument of the given
  // DAG `tree` as an attribute.
  void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
                          int depth);

  // Emits C++ for checking a match with a corresponding match failure
  // diagnostic.
  void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
                      const llvm::formatv_object_base &failureFmt);

  // Emits C++ for checking a match with a corresponding match failure
  // diagnostic.
  void emitMatchCheck(StringRef opName, const std::string &matchStr,
                      const std::string &failureStr);

  //===--------------------------------------------------------------------===//
  // Rewrite utilities
  //===--------------------------------------------------------------------===//

  // The entry point for handling a result pattern rooted at `resultTree`. This
  // method dispatches to concrete handlers according to `resultTree`'s kind and
  // returns a symbol representing the whole value pack. Callers are expected to
  // further resolve the symbol according to the specific use case.
  //
  // `depth` is the nesting level of `resultTree`; 0 means top-level result
  // pattern. For top-level result pattern, `resultIndex` indicates which result
  // of the matched root op this pattern is intended to replace, which can be
  // used to deduce the result type of the op generated from this result
  // pattern.
  std::string handleResultPattern(DagNode resultTree, int resultIndex,
                                  int depth);

  // Emits the C++ statement to replace the matched DAG with a value built via
  // calling native C++ code.
  std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);

  // Returns the symbol of the old value serving as the replacement.
  StringRef handleReplaceWithValue(DagNode tree);

  // Emits the C++ statement to replace the matched DAG with an array of
  // matched values.
  std::string handleVariadic(DagNode tree, int depth);

  // Trailing directives are used at the end of DAG node argument lists to
  // specify additional behaviour for op matchers and creators, etc.
  struct TrailingDirectives {
    // DAG node containing the `location` directive. Null if there is none.
    DagNode location;

    // DAG node containing the `returnType` directive. Null if there is none.
    DagNode returnType;

    // Number of trailing directives found.
    int numDirectives;
  };

  // Collect any trailing directives.
  TrailingDirectives getTrailingDirectives(DagNode tree);

  // Returns the location value to use, honoring a `location` directive found
  // in `tail` if there is one.
  std::string getLocation(TrailingDirectives &tail);

  // Returns the location value to use for the `location` directive `tree`.
  std::string handleLocationDirective(DagNode tree);

  // Emit return type argument.
  std::string handleReturnTypeArg(DagNode returnType, int i, int depth);

  // Emits the C++ statement to build a new op out of the given DAG `tree` and
  // returns the variable name that this op is assigned to. If the root op in
  // DAG `tree` has a specified name, the created op will be assigned to a
  // variable of the given name. Otherwise, a unique name will be used as the
  // result value name.
  std::string handleOpCreation(DagNode tree, int resultIndex, int depth);

  using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;

  // Emits a local variable for each value and attribute to be used for creating
  // an op.
  void createSeparateLocalVarsForOpArgs(DagNode node,
                                        ChildNodeIndexNameMap &childNodeNames);

  // Emits the concrete arguments used to call an op's builder.
  void supplyValuesForOpArgs(DagNode node,
                             const ChildNodeIndexNameMap &childNodeNames,
                             int depth);

  // Emits the local variables for holding all values as a whole and all named
  // attributes as a whole to be used for creating an op.
  void createAggregateLocalVarsForOpArgs(
      DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);

  // Returns the C++ expression to construct a constant attribute of the given
  // `value` for the given attribute kind `attr`.
  std::string handleConstantAttr(Attribute attr, const Twine &value);

  // Returns the C++ expression to build an argument from the given DAG `leaf`.
  // `patArgName` is used to bind the argument to the source pattern.
  std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);

  //===--------------------------------------------------------------------===//
  // General utilities
  //===--------------------------------------------------------------------===//

  // Collects all of the operations within the given dag tree.
  void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);

  // Returns a unique symbol for a local variable of the given `op`.
  std::string getUniqueSymbol(const Operator *op);

  //===--------------------------------------------------------------------===//
  // Symbol utilities
  //===--------------------------------------------------------------------===//

  // Returns how many static values the given DAG `node` corresponds to.
  int getNodeValueCount(DagNode node);

private:
  // Pattern instantiation location followed by the location of multiclass
  // prototypes used. This is intended to be used as a whole to
  // PrintFatalError() on errors.
  ArrayRef<SMLoc> loc;

  // Op's TableGen Record to wrapper object.
  RecordOperatorMap *opMap;

  // Handy wrapper for pattern being emitted.
  Pattern pattern;

  // Map for all bound symbols' info.
  SymbolInfoMap symbolInfoMap;

  StaticMatcherHelper &staticMatcherHelper;

  // The next unused ID for newly created values.
  unsigned nextValueId = 0;

  // Output stream for the generated code; it tracks the current indentation.
  raw_indented_ostream os;

  // Format context holding the placeholder substitutions used with tgfmt.
  FmtContext fmtCtx;
};

// Counts how many times each DagNode is referenced across all patterns. A
// DagNode referenced more than kStaticMatcherThreshold times gets a static
// matcher function instead of having its match logic inlined at every use.
class StaticMatcherHelper {
public:
  StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper,
                      RecordOperatorMap &mapper);

  // Determine if we should inline the match logic or delegate to a static
  // function.
  bool useStaticMatcher(DagNode node) {
    // either/variadic node must be associated to the parentOp, thus we can't
    // emit a static matcher rooted at them.
    if (node.isEither() || node.isVariadic())
      return false;

    return refStats[node] > kStaticMatcherThreshold;
  }

  // Get the name of the static DAG matcher function corresponding to the node.
  // An empty name means the matcher has not been assigned a name yet.
  std::string getMatcherName(DagNode node) {
    assert(useStaticMatcher(node));
    return matcherNames[node];
  }

  // Get the name of static type/attribute verification function.
  StringRef getVerifierName(DagLeaf leaf);

  // Collect the `Record`s, i.e., the DRR, so that we can get the information of
  // the duplicated DAGs.
  void addPattern(Record *record);

  // Emit all static functions of DAG Matcher.
  void populateStaticMatchers(raw_ostream &os);

  // Emit all static functions for Constraints.
  void populateStaticConstraintFunctions(raw_ostream &os);

private:
  static constexpr unsigned kStaticMatcherThreshold = 1;

  // Consider two patterns as down below,
  //   DagNode_Root_A    DagNode_Root_B
  //         |                 |
  //     DagNode_C         DagNode_C
  //         |                 |
  //     DagNode_D         DagNode_D
  //
  // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
  // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
  // multiple times so we'll have static matchers for both of them. When we're
  // emitting the match logic for DagNode_C, we will check if DagNode_D has the
  // static matcher generated. If so, then we'll generate a call to the
  // function, inline otherwise. In this case, inlining is not what we want. As
  // a result, generate the static matcher in topological order to ensure all
  // the dependent static matchers are generated and we can avoid accidentally
  // inlining.
  //
  // The topological order of all the DagNodes among all patterns.
  SmallVector<std::pair<DagNode, Record *>> topologicalOrder;

  RecordOperatorMap &opMap;

  // Records of the static function name of each DagNode
  DenseMap<DagNode, std::string> matcherNames;

  // After collecting all the DagNode in each pattern, `refStats` records the
  // number of users for each DagNode. We will generate the static matcher for a
  // DagNode when the number of users exceeds a certain threshold.
  DenseMap<DagNode, unsigned> refStats;

  // Number of static matchers generated. This is used to generate a unique
  // name for each DagNode.
  int staticMatcherCounter = 0;

  // The DagLeaf which contains type or attr constraint.
  SetVector<DagLeaf> constraints;

  // Static type/attribute verification function emitter.
  StaticVerifierFunctionEmitter staticVerifierEmitter;
};

} // namespace
350 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
351 raw_ostream &os, StaticMatcherHelper &helper)
352 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
353 symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) {
354 fmtCtx.withBuilder("rewriter");
357 std::string PatternEmitter::handleConstantAttr(Attribute attr,
358 const Twine &value) {
359 if (!attr.isConstBuildable())
360 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
361 " does not have the 'constBuilderCall' field");
363 // TODO: Verify the constants here
364 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
367 void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
368 os << formatv(
369 "static ::llvm::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
370 "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
371 "*, 4> &tblgen_ops",
372 funcName);
374 // We pass the reference of the variables that need to be captured. Hence we
375 // need to collect all the symbols in the tree first.
376 pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true);
377 symbolInfoMap.assignUniqueAlternativeNames();
378 for (const auto &info : symbolInfoMap)
379 os << formatv(", {0}", info.second.getArgDecl(info.first));
381 os << ") {\n";
382 os.indent();
383 os << "(void)tblgen_ops;\n";
385 // Note that a static matcher is considered at least one step from the match
386 // entry.
387 emitMatch(tree, "op0", /*depth=*/1);
389 os << "return ::mlir::success();\n";
390 os.unindent();
391 os << "}\n\n";
394 // Helper function to match patterns.
395 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
396 if (tree.isNativeCodeCall()) {
397 emitNativeCodeMatch(tree, name, depth);
398 return;
401 if (tree.isOperation()) {
402 emitOpMatch(tree, name, depth);
403 return;
406 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
409 void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
410 std::string funcName = staticMatcherHelper.getMatcherName(tree);
411 os << formatv("if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", funcName,
412 opName);
414 // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
415 // one pass.
417 // In general, bound symbol should have the unique name in the pattern but
418 // for the operand, binding same symbol to multiple operands imply a
419 // constraint at the same time. In this case, we will rename those operands
420 // with different names. As a result, we need to collect all the symbolInfos
421 // from the DagNode then get the updated name of the local variables from the
422 // global symbolInfoMap.
424 // Collect all the bound symbols in the Dag
425 SymbolInfoMap localSymbolMap(loc);
426 pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);
428 for (const auto &info : localSymbolMap) {
429 auto name = info.first;
430 auto symboInfo = info.second;
431 auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo);
432 os << formatv(", {0}", ret->second.getVarName(name));
435 os << "))) {\n";
436 os.scope().os << "return ::mlir::failure();\n";
437 os << "}\n";
440 void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
441 StringRef opName, StringRef arg,
442 StringRef failureStr) {
443 os << formatv("if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n",
444 funcName, opName, arg, failureStr);
445 os.scope().os << "return ::mlir::failure();\n";
446 os << "}\n";
// Emits C++ that matches the NativeCodeCall DAG `tree` against the op
// referenced by the C++ expression `opName`. One local is declared per DAG
// argument (`arg<depth>_<i>`), the native call is emitted as a match check,
// named arguments are copied into their bound symbols, and the constraints on
// leaf arguments are verified through the static verifier functions.
void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
                                         int depth) {
  LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
  LLVM_DEBUG(tree.print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');

  // The order of generating static matcher follows the topological order so
  // that for every dependent DagNode already have their static matcher
  // generated if needed. The reason we check if `getMatcherName(tree).empty()`
  // is when we are generating the static matcher for a DagNode itself. In this
  // case, we need to emit the function body rather than a function call.
  if (staticMatcherHelper.useStaticMatcher(tree) &&
      !staticMatcherHelper.getMatcherName(tree).empty()) {
    emitStaticMatchCall(tree, opName);

    // NativeCodeCall will never be at depth 0 so that we don't need to catch
    // the root operation as emitOpMatch();

    return;
  }

  // TODO(suderman): iterate through arguments, determine their types, output
  // names.
  // C++ names of the locals declared below, one per DAG argument (trailing
  // directives included), in argument order. They are substituted as the
  // positional placeholders of the NativeCodeCall template.
  SmallVector<std::string, 8> capture;

  // Indents everything emitted below until this function returns.
  raw_indented_ostream::DelimitedScope scope(os);

  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
    std::string argName = formatv("arg{0}_{1}", depth, i);
    if (DagNode argTree = tree.getArgAsNestedDag(i)) {
      if (argTree.isEither())
        PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
      if (argTree.isVariadic())
        PrintFatalError(loc, "NativeCodeCall cannot have `variadic` operands");

      os << "::mlir::Value " << argName << ";\n";
    } else {
      auto leaf = tree.getArgAsLeaf(i);
      if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
        os << "::mlir::Attribute " << argName << ";\n";
      } else {
        os << "::mlir::Value " << argName << ";\n";
      }
    }

    capture.push_back(std::move(argName));
  }

  auto tail = getTrailingDirectives(tree);
  if (tail.returnType)
    PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
  auto locToUse = getLocation(tail);

  // The template must reference the matched op exactly once via `$_self`.
  auto fmt = tree.getNativeCodeTemplate();
  if (fmt.count("$_self") != 1)
    PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
                         "passing the defining Operation");

  // NOTE(review): addSubst/withSelf are applied to the member `fmtCtx`, so the
  // `_loc` and `$_self` substitutions presumably remain set after this call —
  // confirm later uses of `fmtCtx` expect that.
  auto nativeCodeCall = std::string(
      tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()),
            static_cast<ArrayRef<std::string>>(capture)));

  emitMatchCheck(opName, formatv("!::mlir::failed({0})", nativeCodeCall),
                 formatv("\"{0} return ::mlir::failure\"", nativeCodeCall));

  // Copy each named (non-directive) argument into its bound symbol; `_` means
  // the argument is intentionally ignored.
  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
    auto name = tree.getArgName(i);
    if (!name.empty() && name != "_") {
      os << formatv("{0} = {1};\n", name, capture[i]);
    }
  }

  // Verify the constraints attached to the (non-directive) leaf arguments.
  for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
    std::string argName = capture[i];

    // Handle nested DAG construct first
    if (tree.getArgAsNestedDag(i)) {
      PrintFatalError(
          loc, formatv("Matching nested tree in NativeCodecall not support for "
                       "{0} as arg {1}",
                       argName, i));
    }

    DagLeaf leaf = tree.getArgAsLeaf(i);

    // The parameter for native function doesn't bind any constraints.
    if (leaf.isUnspecified())
      continue;

    auto constraint = leaf.getAsConstraint();

    // Attributes are checked directly; values are checked through their type.
    std::string self;
    if (leaf.isAttrMatcher() || leaf.isConstantAttr())
      self = argName;
    else
      self = formatv("{0}.getType()", argName);
    StringRef verifier = staticMatcherHelper.getVerifierName(leaf);
    emitStaticVerifierCall(
        verifier, opName, self,
        formatv("\"operand {0} of native code call '{1}' failed to satisfy "
                "constraint: "
                "'{2}'\"",
                i, tree.getNativeCodeTemplate(),
                escapeString(constraint.getSummary()))
            .str());
  }

  LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
}
// Emits C++ that matches the op DAG `tree` against the op referenced by the
// C++ expression `opName` at nesting level `depth` (0 is the pattern root).
// The op is cast to its ODS class into a local named `castedOp<depth>`, and
// each DAG argument is then matched in turn:
//   - nested op DAGs recurse through emitMatch,
//   - `either`/`variadic` groups go to their dedicated emitters,
//   - leaves are matched as operands or attributes.
void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
  Operator &op = tree.getDialectOp(opMap);
  LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
                          << op.getOperationName() << "' at depth " << depth
                          << '\n');

  auto getCastedName = [depth]() -> std::string {
    return formatv("castedOp{0}", depth);
  };

  // The order of generating static matcher follows the topological order so
  // that for every dependent DagNode already have their static matcher
  // generated if needed. The reason we check if `getMatcherName(tree).empty()`
  // is when we are generating the static matcher for a DagNode itself. In this
  // case, we need to emit the function body rather than a function call.
  if (staticMatcherHelper.useStaticMatcher(tree) &&
      !staticMatcherHelper.getMatcherName(tree).empty()) {
    emitStaticMatchCall(tree, opName);
    // In the codegen of rewriter, we suppose that castedOp0 will capture the
    // root operation. Manually add it if the root DagNode is a static matcher.
    if (depth == 0)
      os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
                    "(void){2};\n",
                    opName, op.getQualCppClassName(), getCastedName());
    return;
  }

  std::string castedName = getCastedName();
  os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
                "(void){0};\n",
                castedName, opName, op.getQualCppClassName());

  // Skip the op-type check at depth 0 as the pattern rewriter already does.
  if (depth != 0)
    emitMatchCheck(opName, /*matchStr=*/castedName,
                   formatv("\"{0} is not {1} type\"", castedName,
                           op.getQualCppClassName()));

  // If the DAG node itself is bound to a symbol, assign the casted op to it.
  auto name = tree.getSymbol();
  if (!name.empty())
    os << formatv("{0} = {1};\n", name, castedName);

  // `i` walks the DAG arguments, `opArgIdx` walks the op's argument list and
  // `nextOperand` counts ODS operand groups. An `either` consumes two op
  // arguments for one DAG argument, hence the extra `++opArgIdx` below.
  for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
       ++i, ++opArgIdx) {
    auto opArg = op.getArg(opArgIdx);
    std::string argName = formatv("op{0}", depth + 1);

    // Handle nested DAG construct first
    if (DagNode argTree = tree.getArgAsNestedDag(i)) {
      if (argTree.isEither()) {
        emitEitherOperandMatch(tree, argTree, castedName, opArgIdx, nextOperand,
                               depth);
        ++opArgIdx;
        continue;
      }
      if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
        if (argTree.isVariadic()) {
          if (!operand->isVariadic()) {
            auto error = formatv("variadic DAG construct can't match op {0}'s "
                                 "non-variadic operand #{1}",
                                 op.getOperationName(), opArgIdx);
            PrintFatalError(loc, error);
          }
          emitVariadicOperandMatch(tree, argTree, castedName, opArgIdx,
                                   nextOperand, depth);
          ++nextOperand;
          continue;
        }
        if (operand->isVariableLength()) {
          auto error = formatv("use nested DAG construct to match op {0}'s "
                               "variadic operand #{1} unsupported now",
                               op.getOperationName(), opArgIdx);
          PrintFatalError(loc, error);
        }
      }

      os << "{\n";

      // Attributes don't count for getODSOperands.
      // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
      os.indent() << formatv(
          "auto *{0} = "
          "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
          argName, castedName, nextOperand);
      // Null check of operand's definingOp
      emitMatchCheck(
          castedName, /*matchStr=*/argName,
          formatv("\"There's no operation that defines operand {0} of {1}\"",
                  nextOperand++, castedName));
      emitMatch(argTree, argName, depth + 1);
      os << formatv("tblgen_ops.push_back({0});\n", argName);
      os.unindent() << "}\n";
      continue;
    }

    // Next handle DAG leaf: operand or attribute
    if (opArg.is<NamedTypeConstraint *>()) {
      auto operandName =
          formatv("{0}.getODSOperands({1})", castedName, nextOperand);
      emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
                       /*operandMatcher=*/tree.getArgAsLeaf(i),
                       /*argName=*/tree.getArgName(i), opArgIdx,
                       /*variadicSubIndex=*/std::nullopt);
      ++nextOperand;
    } else if (opArg.is<NamedAttribute *>()) {
      // NOTE(review): emitAttributeMatch also reads `tree.getArgAsLeaf(argIndex)`
      // with this op-argument index; after an `either`, `opArgIdx` is ahead of
      // the DAG position `i`, so the wrong DAG leaf may be read — confirm.
      emitAttributeMatch(tree, opName, opArgIdx, depth);
    } else {
      PrintFatalError(loc, "unhandled case when matching op");
    }
  }
  LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
                          << op.getOperationName() << "' at depth " << depth
                          << '\n');
}
677 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
678 StringRef operandName, int operandIndex,
679 DagLeaf operandMatcher, StringRef argName,
680 int argIndex,
681 std::optional<int> variadicSubIndex) {
682 Operator &op = tree.getDialectOp(opMap);
683 auto *operand = op.getArg(operandIndex).get<NamedTypeConstraint *>();
685 // If a constraint is specified, we need to generate C++ statements to
686 // check the constraint.
687 if (!operandMatcher.isUnspecified()) {
688 if (!operandMatcher.isOperandMatcher())
689 PrintFatalError(
690 loc, formatv("the {1}-th argument of op '{0}' should be an operand",
691 op.getOperationName(), argIndex + 1));
693 // Only need to verify if the matcher's type is different from the one
694 // of op definition.
695 Constraint constraint = operandMatcher.getAsConstraint();
696 if (operand->constraint != constraint) {
697 if (operand->isVariableLength()) {
698 auto error = formatv(
699 "further constrain op {0}'s variadic operand #{1} unsupported now",
700 op.getOperationName(), argIndex);
701 PrintFatalError(loc, error);
703 auto self = formatv("(*{0}.begin()).getType()", operandName);
704 StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher);
705 emitStaticVerifierCall(
706 verifier, opName, self.str(),
707 formatv(
708 "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
709 operand - op.operand_begin(), op.getOperationName(),
710 escapeString(constraint.getSummary()))
711 .str());
715 // Capture the value
716 // `$_` is a special symbol to ignore op argument matching.
717 if (!argName.empty() && argName != "_") {
718 auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
719 variadicSubIndex);
720 if (res == symbolInfoMap.end())
721 PrintFatalError(loc, formatv("symbol not found: {0}", argName));
723 os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
727 void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
728 StringRef opName, int argIndex,
729 int &operandIndex, int depth) {
730 constexpr int numEitherArgs = 2;
731 if (eitherArgTree.getNumArgs() != numEitherArgs)
732 PrintFatalError(loc, "`either` only supports grouping two operands");
734 Operator &op = tree.getDialectOp(opMap);
736 std::string codeBuffer;
737 llvm::raw_string_ostream tblgenOps(codeBuffer);
739 std::string lambda = formatv("eitherLambda{0}", depth);
740 os << formatv(
741 "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n",
742 lambda);
744 os.indent();
746 for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
747 if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) {
748 if (argTree.isEither())
749 PrintFatalError(loc, "either cannot be nested");
751 std::string argName = formatv("local_op_{0}", i).str();
753 os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName,
756 // Indent emitMatchCheck and emitMatch because they declare local
757 // variables.
758 os << "{\n";
759 os.indent();
761 emitMatchCheck(
762 opName, /*matchStr=*/argName,
763 formatv("\"There's no operation that defines operand {0} of {1}\"",
764 operandIndex++, opName));
765 emitMatch(argTree, argName, depth + 1);
767 os.unindent() << "}\n";
769 // `tblgen_ops` is used to collect the matched operations. In either, we
770 // need to queue the operation only if the matching success. Thus we emit
771 // the code at the end.
772 tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
773 } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
774 emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
775 operandIndex,
776 /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
777 /*argName=*/eitherArgTree.getArgName(i), argIndex,
778 /*variadicSubIndex=*/std::nullopt);
779 ++operandIndex;
780 } else {
781 PrintFatalError(loc, "either can only be applied on operand");
785 os << tblgenOps.str();
786 os << "return ::mlir::success();\n";
787 os.unindent() << "};\n";
789 os << "{\n";
790 os.indent();
792 os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName,
793 operandIndex - 2);
794 os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName,
795 operandIndex - 1);
797 os << formatv("if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && "
798 "::mlir::failed({0}(eitherOperand1, "
799 "eitherOperand0)))\n",
800 lambda);
801 os.indent() << "return ::mlir::failure();\n";
803 os.unindent().unindent() << "}\n";
806 void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
807 DagNode variadicArgTree,
808 StringRef opName, int argIndex,
809 int &operandIndex, int depth) {
810 Operator &op = tree.getDialectOp(opMap);
812 os << "{\n";
813 os.indent();
815 os << formatv("auto variadic_operand_range = {0}.getODSOperands({1});\n",
816 opName, operandIndex);
817 os << formatv("if (variadic_operand_range.size() != {0}) "
818 "return ::mlir::failure();\n",
819 variadicArgTree.getNumArgs());
821 StringRef variadicTreeName = variadicArgTree.getSymbol();
822 if (!variadicTreeName.empty()) {
823 auto res =
824 symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
825 /*variadicSubIndex=*/std::nullopt);
826 if (res == symbolInfoMap.end())
827 PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));
829 os << formatv("{0} = variadic_operand_range;\n",
830 res->second.getVarName(variadicTreeName));
833 for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) {
834 if (DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) {
835 if (!argTree.isOperation())
836 PrintFatalError(loc, "variadic only accepts operation sub-dags");
838 os << "{\n";
839 os.indent();
841 std::string argName = formatv("local_op_{0}", i).str();
842 os << formatv("auto *{0} = "
843 "variadic_operand_range[{1}].getDefiningOp();\n",
844 argName, i);
845 emitMatchCheck(
846 opName, /*matchStr=*/argName,
847 formatv("\"There's no operation that defines variadic operand "
848 "{0} (variadic sub-opearnd #{1}) of {2}\"",
849 operandIndex, i, opName));
850 emitMatch(argTree, argName, depth + 1);
851 os << formatv("tblgen_ops.push_back({0});\n", argName);
853 os.unindent() << "}\n";
854 } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
855 auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i);
856 emitOperandMatch(tree, opName, operandName.str(), operandIndex,
857 /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i),
858 /*argName=*/variadicArgTree.getArgName(i), argIndex, i);
859 } else {
860 PrintFatalError(loc, "variadic can only be applied on operand");
864 os.unindent() << "}\n";
867 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
868 int argIndex, int depth) {
869 Operator &op = tree.getDialectOp(opMap);
870 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
871 const auto &attr = namedAttr->attr;
873 os << "{\n";
874 os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
875 "(void)tblgen_attr;\n",
876 opName, attr.getStorageType(), namedAttr->name);
878 // TODO: This should use getter method to avoid duplication.
879 if (attr.hasDefaultValue()) {
880 os << "if (!tblgen_attr) tblgen_attr = "
881 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
882 attr.getDefaultValue()))
883 << ";\n";
884 } else if (attr.isOptional()) {
885 // For a missing attribute that is optional according to definition, we
886 // should just capture a mlir::Attribute() to signal the missing state.
887 // That is precisely what getDiscardableAttr() returns on missing
888 // attributes.
889 } else {
890 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
891 formatv("\"expected op '{0}' to have attribute '{1}' "
892 "of type '{2}'\"",
893 op.getOperationName(), namedAttr->name,
894 attr.getStorageType()));
897 auto matcher = tree.getArgAsLeaf(argIndex);
898 if (!matcher.isUnspecified()) {
899 if (!matcher.isAttrMatcher()) {
900 PrintFatalError(
901 loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
902 op.getOperationName(), argIndex + 1));
905 // If a constraint is specified, we need to generate function call to its
906 // static verifier.
907 StringRef verifier = staticMatcherHelper.getVerifierName(matcher);
908 if (attr.isOptional()) {
909 // Avoid dereferencing null attribute. This is using a simple heuristic to
910 // avoid common cases of attempting to dereference null attribute. This
911 // will return where there is no check if attribute is null unless the
912 // attribute's value is not used.
913 // FIXME: This could be improved as some null dereferences could slip
914 // through.
915 if (!StringRef(matcher.getConditionTemplate()).contains("!$_self") &&
916 StringRef(matcher.getConditionTemplate()).contains("$_self")) {
917 os << "if (!tblgen_attr) return ::mlir::failure();\n";
920 emitStaticVerifierCall(
921 verifier, opName, "tblgen_attr",
922 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
923 "'{2}'\"",
924 op.getOperationName(), namedAttr->name,
925 escapeString(matcher.getAsConstraint().getSummary()))
926 .str());
929 // Capture the value
930 auto name = tree.getArgName(argIndex);
931 // `$_` is a special symbol to ignore op argument matching.
932 if (!name.empty() && name != "_") {
933 os << formatv("{0} = tblgen_attr;\n", name);
936 os.unindent() << "}\n";
939 void PatternEmitter::emitMatchCheck(
940 StringRef opName, const FmtObjectBase &matchFmt,
941 const llvm::formatv_object_base &failureFmt) {
942 emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
945 void PatternEmitter::emitMatchCheck(StringRef opName,
946 const std::string &matchStr,
947 const std::string &failureStr) {
949 os << "if (!(" << matchStr << "))";
950 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
951 << ", [&](::mlir::Diagnostic &diag) {\n diag << "
952 << failureStr << ";\n});";
955 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
956 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
957 int depth = 0;
958 emitMatch(tree, opName, depth);
960 for (auto &appliedConstraint : pattern.getConstraints()) {
961 auto &constraint = appliedConstraint.constraint;
962 auto &entities = appliedConstraint.entities;
964 auto condition = constraint.getConditionTemplate();
965 if (isa<TypeConstraint>(constraint)) {
966 if (entities.size() != 1)
967 PrintFatalError(loc, "type constraint requires exactly one argument");
969 auto self = formatv("({0}.getType())",
970 symbolInfoMap.getValueAndRangeUse(entities.front()));
971 emitMatchCheck(
972 opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
973 formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"",
974 entities.front(), escapeString(constraint.getSummary())));
976 } else if (isa<AttrConstraint>(constraint)) {
977 PrintFatalError(
978 loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
979 } else {
980 // TODO: replace formatv arguments with the exact specified
981 // args.
982 if (entities.size() > 4) {
983 PrintFatalError(loc, "only support up to 4-entity constraints now");
985 SmallVector<std::string, 4> names;
986 int i = 0;
987 for (int e = entities.size(); i < e; ++i)
988 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
989 std::string self = appliedConstraint.self;
990 if (!self.empty())
991 self = symbolInfoMap.getValueAndRangeUse(self);
992 for (; i < 4; ++i)
993 names.push_back("<unused>");
994 emitMatchCheck(opName,
995 tgfmt(condition, &fmtCtx.withSelf(self), names[0],
996 names[1], names[2], names[3]),
997 formatv("\"entities '{0}' failed to satisfy constraint: "
998 "'{1}'\"",
999 llvm::join(entities, ", "),
1000 escapeString(constraint.getSummary())));
1004 // Some of the operands could be bound to the same symbol name, we need
1005 // to enforce equality constraint on those.
1006 // TODO: we should be able to emit equality checks early
1007 // and short circuit unnecessary work if vars are not equal.
1008 for (auto symbolInfoIt = symbolInfoMap.begin();
1009 symbolInfoIt != symbolInfoMap.end();) {
1010 auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
1011 auto startRange = range.first;
1012 auto endRange = range.second;
1014 auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
1015 for (++startRange; startRange != endRange; ++startRange) {
1016 auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
1017 emitMatchCheck(
1018 opName,
1019 formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
1020 formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
1021 secondOperand));
1024 symbolInfoIt = endRange;
1027 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
1030 void PatternEmitter::collectOps(DagNode tree,
1031 llvm::SmallPtrSetImpl<const Operator *> &ops) {
1032 // Check if this tree is an operation.
1033 if (tree.isOperation()) {
1034 const Operator &op = tree.getDialectOp(opMap);
1035 LLVM_DEBUG(llvm::dbgs()
1036 << "found operation " << op.getOperationName() << '\n');
1037 ops.insert(&op);
1040 // Recurse the arguments of the tree.
1041 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
1042 if (auto child = tree.getArgAsNestedDag(i))
1043 collectOps(child, ops);
1046 void PatternEmitter::emit(StringRef rewriteName) {
1047 // Get the DAG tree for the source pattern.
1048 DagNode sourceTree = pattern.getSourcePattern();
1050 const Operator &rootOp = pattern.getSourceRootOp();
1051 auto rootName = rootOp.getOperationName();
1053 // Collect the set of result operations.
1054 llvm::SmallPtrSet<const Operator *, 4> resultOps;
1055 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
1056 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
1057 collectOps(pattern.getResultPattern(i), resultOps);
1059 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
1061 // Emit RewritePattern for Pattern.
1062 auto locs = pattern.getLocation();
1063 os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
1064 make_range(locs.rbegin(), locs.rend()));
1065 os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
1066 {0}(::mlir::MLIRContext *context)
1067 : ::mlir::RewritePattern("{1}", {2}, context, {{)",
1068 rewriteName, rootName, pattern.getBenefit());
1069 // Sort result operators by name.
1070 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
1071 resultOps.end());
1072 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
1073 return lhs->getOperationName() < rhs->getOperationName();
1075 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
1076 os << '"' << op->getOperationName() << '"';
1078 os << "}) {}\n";
1080 // Emit matchAndRewrite() function.
1082 auto classScope = os.scope();
1083 os.printReindented(R"(
1084 ::llvm::LogicalResult matchAndRewrite(::mlir::Operation *op0,
1085 ::mlir::PatternRewriter &rewriter) const override {)")
1086 << '\n';
1088 auto functionScope = os.scope();
1090 // Register all symbols bound in the source pattern.
1091 pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
1093 LLVM_DEBUG(llvm::dbgs()
1094 << "start creating local variables for capturing matches\n");
1095 os << "// Variables for capturing values and attributes used while "
1096 "creating ops\n";
1097 // Create local variables for storing the arguments and results bound
1098 // to symbols.
1099 for (const auto &symbolInfoPair : symbolInfoMap) {
1100 const auto &symbol = symbolInfoPair.first;
1101 const auto &info = symbolInfoPair.second;
1103 os << info.getVarDecl(symbol);
1105 // TODO: capture ops with consistent numbering so that it can be
1106 // reused for fused loc.
1107 os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
1108 LLVM_DEBUG(llvm::dbgs()
1109 << "done creating local variables for capturing matches\n");
1111 os << "// Match\n";
1112 os << "tblgen_ops.push_back(op0);\n";
1113 emitMatchLogic(sourceTree, "op0");
1115 os << "\n// Rewrite\n";
1116 emitRewriteLogic();
1118 os << "return ::mlir::success();\n";
1120 os << "}\n";
1122 os << "};\n\n";
1125 void PatternEmitter::emitRewriteLogic() {
1126 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
1127 const Operator &rootOp = pattern.getSourceRootOp();
1128 int numExpectedResults = rootOp.getNumResults();
1129 int numResultPatterns = pattern.getNumResultPatterns();
1131 // First register all symbols bound to ops generated in result patterns.
1132 pattern.collectResultPatternBoundSymbols(symbolInfoMap);
1134 // Only the last N static values generated are used to replace the matched
1135 // root N-result op. We need to calculate the starting index (of the results
1136 // of the matched op) each result pattern is to replace.
1137 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
1138 // If we don't need to replace any value at all, set the replacement starting
1139 // index as the number of result patterns so we skip all of them when trying
1140 // to replace the matched op's results.
1141 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
1142 for (int i = numResultPatterns - 1; i >= 0; --i) {
1143 auto numValues = getNodeValueCount(pattern.getResultPattern(i));
1144 offsets[i] = offsets[i + 1] - numValues;
1145 if (offsets[i] == 0) {
1146 if (replStartIndex == -1)
1147 replStartIndex = i;
1148 } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
1149 auto error = formatv(
1150 "cannot use the same multi-result op '{0}' to generate both "
1151 "auxiliary values and values to be used for replacing the matched op",
1152 pattern.getResultPattern(i).getSymbol());
1153 PrintFatalError(loc, error);
1157 if (offsets.front() > 0) {
1158 const char error[] =
1159 "not enough values generated to replace the matched op";
1160 PrintFatalError(loc, error);
1163 os << "auto odsLoc = rewriter.getFusedLoc({";
1164 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
1165 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
1167 os << "}); (void)odsLoc;\n";
1169 // Process auxiliary result patterns.
1170 for (int i = 0; i < replStartIndex; ++i) {
1171 DagNode resultTree = pattern.getResultPattern(i);
1172 auto val = handleResultPattern(resultTree, offsets[i], 0);
1173 // Normal op creation will be streamed to `os` by the above call; but
1174 // NativeCodeCall will only be materialized to `os` if it is used. Here
1175 // we are handling auxiliary patterns so we want the side effect even if
1176 // NativeCodeCall is not replacing matched root op's results.
1177 if (resultTree.isNativeCodeCall() &&
1178 resultTree.getNumReturnsOfNativeCode() == 0)
1179 os << val << ";\n";
1182 auto processSupplementalPatterns = [&]() {
1183 int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
1184 for (int i = 0, offset = -numSupplementalPatterns;
1185 i < numSupplementalPatterns; ++i) {
1186 DagNode resultTree = pattern.getSupplementalPattern(i);
1187 auto val = handleResultPattern(resultTree, offset++, 0);
1188 if (resultTree.isNativeCodeCall() &&
1189 resultTree.getNumReturnsOfNativeCode() == 0)
1190 os << val << ";\n";
1194 if (numExpectedResults == 0) {
1195 assert(replStartIndex >= numResultPatterns &&
1196 "invalid auxiliary vs. replacement pattern division!");
1197 processSupplementalPatterns();
1198 // No result to replace. Just erase the op.
1199 os << "rewriter.eraseOp(op0);\n";
1200 } else {
1201 // Process replacement result patterns.
1202 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
1203 for (int i = replStartIndex; i < numResultPatterns; ++i) {
1204 DagNode resultTree = pattern.getResultPattern(i);
1205 auto val = handleResultPattern(resultTree, offsets[i], 0);
1206 os << "\n";
1207 // Resolve each symbol for all range use so that we can loop over them.
1208 // We need an explicit cast to `SmallVector` to capture the cases where
1209 // `{0}` resolves to an `Operation::result_range` as well as cases that
1210 // are not iterable (e.g. vector that gets wrapped in additional braces by
1211 // RewriterGen).
1212 // TODO: Revisit the need for materializing a vector.
1213 os << symbolInfoMap.getAllRangeUse(
1214 val,
1215 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
1216 " tblgen_repl_values.push_back(v);\n}\n",
1217 "\n");
1219 processSupplementalPatterns();
1220 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
1223 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
1226 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
1227 return std::string(
1228 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
1231 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
1232 int resultIndex, int depth) {
1233 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
1234 LLVM_DEBUG(resultTree.print(llvm::dbgs()));
1235 LLVM_DEBUG(llvm::dbgs() << '\n');
1237 if (resultTree.isLocationDirective()) {
1238 PrintFatalError(loc,
1239 "location directive can only be used with op creation");
1242 if (resultTree.isNativeCodeCall())
1243 return handleReplaceWithNativeCodeCall(resultTree, depth);
1245 if (resultTree.isReplaceWithValue())
1246 return handleReplaceWithValue(resultTree).str();
1248 if (resultTree.isVariadic())
1249 return handleVariadic(resultTree, depth);
1251 // Normal op creation.
1252 auto symbol = handleOpCreation(resultTree, resultIndex, depth);
1253 if (resultTree.getSymbol().empty()) {
1254 // This is an op not explicitly bound to a symbol in the rewrite rule.
1255 // Register the auto-generated symbol for it.
1256 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
1258 return symbol;
1261 std::string PatternEmitter::handleVariadic(DagNode tree, int depth) {
1262 assert(tree.isVariadic());
1264 std::string output;
1265 llvm::raw_string_ostream oss(output);
1266 auto name = std::string(formatv("tblgen_variadic_values_{0}", nextValueId++));
1267 symbolInfoMap.bindValue(name);
1268 oss << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n";
1269 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1270 if (auto child = tree.getArgAsNestedDag(i)) {
1271 oss << name << ".push_back(" << handleResultPattern(child, i, depth + 1)
1272 << ");\n";
1273 } else {
1274 oss << name << ".push_back("
1275 << handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i))
1276 << ");\n";
1280 os << oss.str();
1281 return name;
1284 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
1285 assert(tree.isReplaceWithValue());
1287 if (tree.getNumArgs() != 1) {
1288 PrintFatalError(
1289 loc, "replaceWithValue directive must take exactly one argument");
1292 if (!tree.getSymbol().empty()) {
1293 PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
1296 return tree.getArgName(0);
1299 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
1300 assert(tree.isLocationDirective());
1301 auto lookUpArgLoc = [this, &tree](int idx) {
1302 const auto *const lookupFmt = "{0}.getLoc()";
1303 return symbolInfoMap.getValueAndRangeUse(tree.getArgName(idx), lookupFmt);
1306 if (tree.getNumArgs() == 0)
1307 llvm::PrintFatalError(
1308 "At least one argument to location directive required");
1310 if (!tree.getSymbol().empty())
1311 PrintFatalError(loc, "cannot bind symbol to location");
1313 if (tree.getNumArgs() == 1) {
1314 DagLeaf leaf = tree.getArgAsLeaf(0);
1315 if (leaf.isStringAttr())
1316 return formatv("::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))",
1317 leaf.getStringAttr())
1318 .str();
1319 return lookUpArgLoc(0);
1322 std::string ret;
1323 llvm::raw_string_ostream os(ret);
1324 std::string strAttr;
1325 os << "rewriter.getFusedLoc({";
1326 bool first = true;
1327 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1328 DagLeaf leaf = tree.getArgAsLeaf(i);
1329 // Handle the optional string value.
1330 if (leaf.isStringAttr()) {
1331 if (!strAttr.empty())
1332 llvm::PrintFatalError("Only one string attribute may be specified");
1333 strAttr = leaf.getStringAttr();
1334 continue;
1336 os << (first ? "" : ", ") << lookUpArgLoc(i);
1337 first = false;
1339 os << "}";
1340 if (!strAttr.empty()) {
1341 os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
1343 os << ")";
1344 return os.str();
1347 std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
1348 int depth) {
1349 // Nested NativeCodeCall.
1350 if (auto dagNode = returnType.getArgAsNestedDag(i)) {
1351 if (!dagNode.isNativeCodeCall())
1352 PrintFatalError(loc, "nested DAG in `returnType` must be a native code "
1353 "call");
1354 return handleReplaceWithNativeCodeCall(dagNode, depth);
1356 // String literal.
1357 auto dagLeaf = returnType.getArgAsLeaf(i);
1358 if (dagLeaf.isStringAttr())
1359 return tgfmt(dagLeaf.getStringAttr(), &fmtCtx);
1360 return tgfmt(
1361 "$0.getType()", &fmtCtx,
1362 handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i)));
1365 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
1366 StringRef patArgName) {
1367 if (leaf.isStringAttr())
1368 PrintFatalError(loc, "raw string not supported as argument");
1369 if (leaf.isConstantAttr()) {
1370 auto constAttr = leaf.getAsConstantAttr();
1371 return handleConstantAttr(constAttr.getAttribute(),
1372 constAttr.getConstantValue());
1374 if (leaf.isEnumAttrCase()) {
1375 auto enumCase = leaf.getAsEnumAttrCase();
1376 // This is an enum case backed by an IntegerAttr. We need to get its value
1377 // to build the constant.
1378 std::string val = std::to_string(enumCase.getValue());
1379 return handleConstantAttr(enumCase, val);
1382 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
1383 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
1384 if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
1385 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
1386 << "' (via symbol ref)\n");
1387 return argName;
1389 if (leaf.isNativeCodeCall()) {
1390 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
1391 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
1392 << "' (via NativeCodeCall)\n");
1393 return std::string(repl);
1395 PrintFatalError(loc, "unhandled case when rewriting op");
1398 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
1399 int depth) {
1400 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
1401 LLVM_DEBUG(tree.print(llvm::dbgs()));
1402 LLVM_DEBUG(llvm::dbgs() << '\n');
1404 auto fmt = tree.getNativeCodeTemplate();
1406 SmallVector<std::string, 16> attrs;
1408 auto tail = getTrailingDirectives(tree);
1409 if (tail.returnType)
1410 PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
1411 auto locToUse = getLocation(tail);
1413 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1414 if (tree.isNestedDagArg(i)) {
1415 attrs.push_back(
1416 handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
1417 } else {
1418 attrs.push_back(
1419 handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
1421 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
1422 << " replacement: " << attrs[i] << "\n");
1425 std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse),
1426 static_cast<ArrayRef<std::string>>(attrs));
1428 // In general, NativeCodeCall without naming binding don't need this. To
1429 // ensure void helper function has been correctly labeled, i.e., use
1430 // NativeCodeCallVoid, we cache the result to a local variable so that we will
1431 // get a compilation error in the auto-generated file.
1432 // Example.
1433 // // In the td file
1434 // Pat<(...), (NativeCodeCall<Foo> ...)>
1436 // ---
1438 // // In the auto-generated .cpp
1439 // ...
1440 // // Causes compilation error if Foo() returns void.
1441 // auto nativeVar = Foo();
1442 // ...
1443 if (tree.getNumReturnsOfNativeCode() != 0) {
1444 // Determine the local variable name for return value.
1445 std::string varName =
1446 SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
1447 if (varName.empty()) {
1448 varName = formatv("nativeVar_{0}", nextValueId++);
1449 // Register the local variable for later uses.
1450 symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
1453 // Catch the return value of helper function.
1454 os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
1456 if (!tree.getSymbol().empty())
1457 symbol = tree.getSymbol().str();
1458 else
1459 symbol = varName;
1462 return symbol;
1465 int PatternEmitter::getNodeValueCount(DagNode node) {
1466 if (node.isOperation()) {
1467 // If the op is bound to a symbol in the rewrite rule, query its result
1468 // count from the symbol info map.
1469 auto symbol = node.getSymbol();
1470 if (!symbol.empty()) {
1471 return symbolInfoMap.getStaticValueCount(symbol);
1473 // Otherwise this is an unbound op; we will use all its results.
1474 return pattern.getDialectOp(node).getNumResults();
1477 if (node.isNativeCodeCall())
1478 return node.getNumReturnsOfNativeCode();
1480 return 1;
1483 PatternEmitter::TrailingDirectives
1484 PatternEmitter::getTrailingDirectives(DagNode tree) {
1485 TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};
1487 // Look backwards through the arguments.
1488 auto numPatArgs = tree.getNumArgs();
1489 for (int i = numPatArgs - 1; i >= 0; --i) {
1490 auto dagArg = tree.getArgAsNestedDag(i);
1491 // A leaf is not a directive. Stop looking.
1492 if (!dagArg)
1493 break;
1495 auto isLocation = dagArg.isLocationDirective();
1496 auto isReturnType = dagArg.isReturnTypeDirective();
1497 // If encountered a DAG node that isn't a trailing directive, stop looking.
1498 if (!(isLocation || isReturnType))
1499 break;
1500 // Save the directive, but error if one of the same type was already
1501 // found.
1502 ++tail.numDirectives;
1503 if (isLocation) {
1504 if (tail.location)
1505 PrintFatalError(loc, "`location` directive can only be specified "
1506 "once");
1507 tail.location = dagArg;
1508 } else if (isReturnType) {
1509 if (tail.returnType)
1510 PrintFatalError(loc, "`returnType` directive can only be specified "
1511 "once");
1512 tail.returnType = dagArg;
1516 return tail;
1519 std::string
1520 PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
1521 if (tail.location)
1522 return handleLocationDirective(tail.location);
1524 // If no explicit location is given, use the default, all fused, location.
1525 return "odsLoc";
1528 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
1529 int depth) {
1530 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
1531 LLVM_DEBUG(tree.print(llvm::dbgs()));
1532 LLVM_DEBUG(llvm::dbgs() << '\n');
1534 Operator &resultOp = tree.getDialectOp(opMap);
1535 auto numOpArgs = resultOp.getNumArgs();
1536 auto numPatArgs = tree.getNumArgs();
1538 auto tail = getTrailingDirectives(tree);
1539 auto locToUse = getLocation(tail);
1541 auto inPattern = numPatArgs - tail.numDirectives;
1542 if (numOpArgs != inPattern) {
1543 PrintFatalError(loc,
1544 formatv("resultant op '{0}' argument number mismatch: "
1545 "{1} in pattern vs. {2} in definition",
1546 resultOp.getOperationName(), inPattern, numOpArgs));
1549 // A map to collect all nested DAG child nodes' names, with operand index as
1550 // the key. This includes both bound and unbound child nodes.
1551 ChildNodeIndexNameMap childNodeNames;
1553 // If the argument is a type constraint, then its an operand. Check if the
1554 // op's argument is variadic that the argument in the pattern is too.
1555 auto checkIfMatchedVariadic = [&](int i) {
1556 // FIXME: This does not yet check for variable/leaf case.
1557 // FIXME: Change so that native code call can be handled.
1558 const auto *operand =
1559 llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(i));
1560 if (!operand || !operand->isVariadic())
1561 return;
1563 auto child = tree.getArgAsNestedDag(i);
1564 if (!child)
1565 return;
1567 // Skip over replaceWithValues.
1568 while (child.isReplaceWithValue()) {
1569 if (!(child = child.getArgAsNestedDag(0)))
1570 return;
1572 if (!child.isNativeCodeCall() && !child.isVariadic())
1573 PrintFatalError(loc, formatv("op expects variadic operand `{0}`, while "
1574 "provided is non-variadic",
1575 resultOp.getArgName(i)));
1578 // First go through all the child nodes who are nested DAG constructs to
1579 // create ops for them and remember the symbol names for them, so that we can
1580 // use the results in the current node. This happens in a recursive manner.
1581 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1582 checkIfMatchedVariadic(i);
1583 if (auto child = tree.getArgAsNestedDag(i))
1584 childNodeNames[i] = handleResultPattern(child, i, depth + 1);
1587 // The name of the local variable holding this op.
1588 std::string valuePackName;
1589 // The symbol for holding the result of this pattern. Note that the result of
1590 // this pattern is not necessarily the same as the variable created by this
1591 // pattern because we can use `__N` suffix to refer only a specific result if
1592 // the generated op is a multi-result op.
1593 std::string resultValue;
1594 if (tree.getSymbol().empty()) {
1595 // No symbol is explicitly bound to this op in the pattern. Generate a
1596 // unique name.
1597 valuePackName = resultValue = getUniqueSymbol(&resultOp);
1598 } else {
1599 resultValue = std::string(tree.getSymbol());
1600 // Strip the index to get the name for the value pack and use it to name the
1601 // local variable for the op.
1602 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
1605 // Create the local variable for this op.
1606 os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
1607 valuePackName);
1609 // Right now ODS don't have general type inference support. Except a few
1610 // special cases listed below, DRR needs to supply types for all results
1611 // when building an op.
1612 bool isSameOperandsAndResultType =
1613 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
1614 bool useFirstAttr =
1615 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
1617 if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
1618 // We know how to deduce the result type for ops with these traits and we've
1619 // generated builders taking aggregate parameters. Use those builders to
1620 // create the ops.
1622 // First prepare local variables for op arguments used in builder call.
1623 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1625 // Then create the op.
1626 os.scope("", "\n}\n").os << formatv(
1627 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1628 valuePackName, resultOp.getQualCppClassName(), locToUse);
1629 return resultValue;
1632 bool usePartialResults = valuePackName != resultValue;
1634 if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
1635 // For these cases (broadcastable ops, op results used both as auxiliary
1636 // values and replacement values, ops in nested patterns, auxiliary ops), we
1637 // still need to supply the result types when building the op. But because
1638 // we don't generate a builder automatically with ODS for them, it's the
1639 // developer's responsibility to make sure such a builder (with result type
1640 // deduction ability) exists. We go through the separate-parameter builder
1641 // here given that it's easier for developers to write compared to
1642 // aggregate-parameter builders.
1643 createSeparateLocalVarsForOpArgs(tree, childNodeNames);
1645 os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
1646 resultOp.getQualCppClassName(), locToUse);
1647 supplyValuesForOpArgs(tree, childNodeNames, depth);
1648 os << "\n );\n}\n";
1649 return resultValue;
1652 // If we are provided explicit return types, use them to build the op.
1653 // However, if depth == 0 and resultIndex >= 0, it means we are replacing
1654 // the values generated from the source pattern root op. Then we must use the
1655 // source pattern's value types to determine the value type of the generated
1656 // op here.
1657 if (depth == 0 && resultIndex >= 0 && tail.returnType)
1658 PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
1659 "return values replace the source pattern's root op");
1661 // First prepare local variables for op arguments used in builder call.
1662 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1664 // Then prepare the result types. We need to specify the types for all
1665 // results.
1666 os.indent() << formatv("::llvm::SmallVector<::mlir::Type, 4> tblgen_types; "
1667 "(void)tblgen_types;\n");
1668 int numResults = resultOp.getNumResults();
1669 if (tail.returnType) {
1670 auto numRetTys = tail.returnType.getNumArgs();
1671 for (int i = 0; i < numRetTys; ++i) {
1672 auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
1673 os << "tblgen_types.push_back(" << varName << ");\n";
1675 } else {
1676 if (numResults != 0) {
1677 // Copy the result types from the source pattern.
1678 for (int i = 0; i < numResults; ++i)
1679 os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
1680 " tblgen_types.push_back(v.getType());\n}\n",
1681 resultIndex + i);
1684 os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
1685 "tblgen_values, tblgen_attrs);\n",
1686 valuePackName, resultOp.getQualCppClassName(), locToUse);
1687 os.unindent() << "}\n";
1688 return resultValue;
1691 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1692 DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1693 Operator &resultOp = node.getDialectOp(opMap);
1695 // Now prepare operands used for building this op:
1696 // * If the operand is non-variadic, we create a `Value` local variable.
1697 // * If the operand is variadic, we create a `SmallVector<Value>` local
1698 // variable.
1700 int valueIndex = 0; // An index for uniquing local variable names.
1701 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1702 const auto *operand =
1703 llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(argIndex));
1704 // We do not need special handling for attributes.
1705 if (!operand)
1706 continue;
1708 raw_indented_ostream::DelimitedScope scope(os);
1709 std::string varName;
1710 if (operand->isVariadic()) {
1711 varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1712 os << formatv("::llvm::SmallVector<::mlir::Value, 4> {0};\n", varName);
1713 std::string range;
1714 if (node.isNestedDagArg(argIndex)) {
1715 range = childNodeNames[argIndex];
1716 } else {
1717 range = std::string(node.getArgName(argIndex));
1719 // Resolve the symbol for all range use so that we have a uniform way of
1720 // capturing the values.
1721 range = symbolInfoMap.getValueAndRangeUse(range);
1722 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range,
1723 varName);
1724 } else {
1725 varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1726 os << formatv("::mlir::Value {0} = ", varName);
1727 if (node.isNestedDagArg(argIndex)) {
1728 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1729 } else {
1730 DagLeaf leaf = node.getArgAsLeaf(argIndex);
1731 auto symbol =
1732 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1733 if (leaf.isNativeCodeCall()) {
1734 os << std::string(
1735 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1736 } else {
1737 os << symbol;
1740 os << ";\n";
1743 // Update to use the newly created local variable for building the op later.
1744 childNodeNames[argIndex] = varName;
1748 void PatternEmitter::supplyValuesForOpArgs(
1749 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1750 Operator &resultOp = node.getDialectOp(opMap);
1751 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1752 argIndex != numOpArgs; ++argIndex) {
1753 // Start each argument on its own line.
1754 os << ",\n ";
1756 Argument opArg = resultOp.getArg(argIndex);
1757 // Handle the case of operand first.
1758 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
1759 if (!operand->name.empty())
1760 os << "/*" << operand->name << "=*/";
1761 os << childNodeNames.lookup(argIndex);
1762 continue;
1765 // The argument in the op definition.
1766 auto opArgName = resultOp.getArgName(argIndex);
1767 if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1768 if (!subTree.isNativeCodeCall())
1769 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1770 "for creating attribute");
1771 os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
1772 } else {
1773 auto leaf = node.getArgAsLeaf(argIndex);
1774 // The argument in the result DAG pattern.
1775 auto patArgName = node.getArgName(argIndex);
1776 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1777 // TODO: Refactor out into map to avoid recomputing these.
1778 if (!opArg.is<NamedAttribute *>())
1779 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1780 if (!patArgName.empty())
1781 os << "/*" << patArgName << "=*/";
1782 } else {
1783 os << "/*" << opArgName << "=*/";
1785 os << handleOpArgument(leaf, patArgName);
1790 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1791 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1792 Operator &resultOp = node.getDialectOp(opMap);
1794 auto scope = os.scope();
1795 os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
1796 "tblgen_values; (void)tblgen_values;\n");
1797 os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1798 "tblgen_attrs; (void)tblgen_attrs;\n");
1800 const char *addAttrCmd =
1801 "if (auto tmpAttr = {1}) {\n"
1802 " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
1803 "tmpAttr);\n}\n";
1804 int numVariadic = 0;
1805 bool hasOperandSegmentSizes = false;
1806 std::vector<std::string> sizes;
1807 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1808 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1809 // The argument in the op definition.
1810 auto opArgName = resultOp.getArgName(argIndex);
1811 hasOperandSegmentSizes =
1812 hasOperandSegmentSizes || opArgName == "operandSegmentSizes";
1813 if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1814 if (!subTree.isNativeCodeCall())
1815 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1816 "for creating attribute");
1817 os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
1818 } else {
1819 auto leaf = node.getArgAsLeaf(argIndex);
1820 // The argument in the result DAG pattern.
1821 auto patArgName = node.getArgName(argIndex);
1822 os << formatv(addAttrCmd, opArgName,
1823 handleOpArgument(leaf, patArgName));
1825 continue;
1828 const auto *operand =
1829 resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1830 std::string varName;
1831 if (operand->isVariadic()) {
1832 ++numVariadic;
1833 std::string range;
1834 if (node.isNestedDagArg(argIndex)) {
1835 range = childNodeNames.lookup(argIndex);
1836 } else {
1837 range = std::string(node.getArgName(argIndex));
1839 // Resolve the symbol for all range use so that we have a uniform way of
1840 // capturing the values.
1841 range = symbolInfoMap.getValueAndRangeUse(range);
1842 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
1843 range);
1844 sizes.push_back(formatv("static_cast<int32_t>({0}.size())", range));
1845 } else {
1846 sizes.emplace_back("1");
1847 os << formatv("tblgen_values.push_back(");
1848 if (node.isNestedDagArg(argIndex)) {
1849 os << symbolInfoMap.getValueAndRangeUse(
1850 childNodeNames.lookup(argIndex));
1851 } else {
1852 DagLeaf leaf = node.getArgAsLeaf(argIndex);
1853 if (leaf.isConstantAttr())
1854 // TODO: Use better location
1855 PrintFatalError(
1856 loc,
1857 "attribute found where value was expected, if attempting to use "
1858 "constant value, construct a constant op with given attribute "
1859 "instead");
1861 auto symbol =
1862 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1863 if (leaf.isNativeCodeCall()) {
1864 os << std::string(
1865 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1866 } else {
1867 os << symbol;
1870 os << ");\n";
1874 if (numVariadic > 1 && !hasOperandSegmentSizes) {
1875 // Only set size if it can't be computed.
1876 const auto *sameVariadicSize =
1877 resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
1878 if (!sameVariadicSize) {
1879 const char *setSizes = R"(
1880 tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1881 rewriter.getDenseI32ArrayAttr({{ {0} }));
1883 os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
1888 StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
1889 const RecordKeeper &recordKeeper,
1890 RecordOperatorMap &mapper)
1891 : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {}
1893 void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
1894 // PatternEmitter will use the static matcher if there's one generated. To
1895 // ensure that all the dependent static matchers are generated before emitting
1896 // the matching logic of the DagNode, we use topological order to achieve it.
1897 for (auto &dagInfo : topologicalOrder) {
1898 DagNode node = dagInfo.first;
1899 if (!useStaticMatcher(node))
1900 continue;
1902 std::string funcName =
1903 formatv("static_dag_matcher_{0}", staticMatcherCounter++);
1904 assert(!matcherNames.contains(node));
1905 PatternEmitter(dagInfo.second, &opMap, os, *this)
1906 .emitStaticMatcher(node, funcName);
1907 matcherNames[node] = funcName;
1911 void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
1912 staticVerifierEmitter.emitPatternConstraints(constraints.getArrayRef());
1915 void StaticMatcherHelper::addPattern(Record *record) {
1916 Pattern pat(record, &opMap);
1918 // While generating the function body of the DAG matcher, it may depends on
1919 // other DAG matchers. To ensure the dependent matchers are ready, we compute
1920 // the topological order for all the DAGs and emit the DAG matchers in this
1921 // order.
1922 llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
1923 ++refStats[node];
1925 if (refStats[node] != 1)
1926 return;
1928 for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
1929 if (DagNode sibling = node.getArgAsNestedDag(i))
1930 dfs(sibling);
1931 else {
1932 DagLeaf leaf = node.getArgAsLeaf(i);
1933 if (!leaf.isUnspecified())
1934 constraints.insert(leaf);
1937 topologicalOrder.push_back(std::make_pair(node, record));
1940 dfs(pat.getSourcePattern());
1943 StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
1944 if (leaf.isAttrMatcher()) {
1945 std::optional<StringRef> constraint =
1946 staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint());
1947 assert(constraint && "attribute constraint was not uniqued");
1948 return *constraint;
1950 assert(leaf.isOperandMatcher());
1951 return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
1954 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1955 emitSourceFileHeader("Rewriters", os, recordKeeper);
1957 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1959 // We put the map here because it can be shared among multiple patterns.
1960 RecordOperatorMap recordOpMap;
1962 // Exam all the patterns and generate static matcher for the duplicated
1963 // DagNode.
1964 StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap);
1965 for (Record *p : patterns)
1966 staticMatcher.addPattern(p);
1967 staticMatcher.populateStaticConstraintFunctions(os);
1968 staticMatcher.populateStaticMatchers(os);
1970 std::vector<std::string> rewriterNames;
1971 rewriterNames.reserve(patterns.size());
1973 std::string baseRewriterName = "GeneratedConvert";
1974 int rewriterIndex = 0;
1976 for (Record *p : patterns) {
1977 std::string name;
1978 if (p->isAnonymous()) {
1979 // If no name is provided, ensure unique rewriter names simply by
1980 // appending unique suffix.
1981 name = baseRewriterName + llvm::utostr(rewriterIndex++);
1982 } else {
1983 name = std::string(p->getName());
1985 LLVM_DEBUG(llvm::dbgs()
1986 << "=== start generating pattern '" << name << "' ===\n");
1987 PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
1988 LLVM_DEBUG(llvm::dbgs()
1989 << "=== done generating pattern '" << name << "' ===\n");
1990 rewriterNames.push_back(std::move(name));
1993 // Emit function to add the generated matchers to the pattern list.
1994 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
1995 "::mlir::RewritePatternSet &patterns) {\n";
1996 for (const auto &name : rewriterNames) {
1997 os << " patterns.add<" << name << ">(patterns.getContext());\n";
1999 os << "}\n";
2002 static mlir::GenRegistration
2003 genRewriters("gen-rewriters", "Generate pattern rewriters",
2004 [](const RecordKeeper &records, raw_ostream &os) {
2005 emitRewriters(records, os);
2006 return false;