[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / tools / mlir-tblgen / RewriterGen.cpp
blob77c34cb03e987ea248ad856ad01025497a334552
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Operator.h"
19 #include "mlir/TableGen/Pattern.h"
20 #include "mlir/TableGen/Predicate.h"
21 #include "mlir/TableGen/Type.h"
22 #include "llvm/ADT/FunctionExtras.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/StringSet.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/FormatAdapters.h"
29 #include "llvm/Support/PrettyStackTrace.h"
30 #include "llvm/Support/Signals.h"
31 #include "llvm/TableGen/Error.h"
32 #include "llvm/TableGen/Main.h"
33 #include "llvm/TableGen/Record.h"
34 #include "llvm/TableGen/TableGenBackend.h"
36 using namespace mlir;
37 using namespace mlir::tblgen;
39 using llvm::formatv;
40 using llvm::Record;
41 using llvm::RecordKeeper;
43 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
45 namespace llvm {
46 template <>
47 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
48 static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
49 raw_ostream &os, StringRef style) {
50 os << v.first << ":" << v.second;
53 } // namespace llvm
55 //===----------------------------------------------------------------------===//
56 // PatternEmitter
57 //===----------------------------------------------------------------------===//
59 namespace {
61 class StaticMatcherHelper;
// Emits the C++ code for a single rewrite pattern record: the match logic for
// the source pattern and the rewrite logic for the result pattern(s). Static
// matcher and verifier function names are obtained from the shared
// StaticMatcherHelper passed to the constructor.
63 class PatternEmitter {
64 public:
65 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
66 StaticMatcherHelper &helper);
68 // Emits the mlir::RewritePattern struct named `rewriteName`.
69 void emit(StringRef rewriteName);
71 // Emits the static function of DAG matcher.
72 void emitStaticMatcher(DagNode tree, std::string funcName);
74 private:
75 // Emits the code for matching ops.
76 void emitMatchLogic(DagNode tree, StringRef opName);
78 // Emits the code for rewriting ops.
79 void emitRewriteLogic();
81 //===--------------------------------------------------------------------===//
82 // Match utilities
83 //===--------------------------------------------------------------------===//
85 // Emits C++ statements for matching the DAG structure.
86 void emitMatch(DagNode tree, StringRef name, int depth);
88 // Emit C++ function call to static DAG matcher.
89 void emitStaticMatchCall(DagNode tree, StringRef name);
91 // Emit C++ function call to static type/attribute constraint function.
92 void emitStaticVerifierCall(StringRef funcName, StringRef opName,
93 StringRef arg, StringRef failureStr);
95 // Emits C++ statements for matching using a native code call.
96 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
98 // Emits C++ statements for matching the op constrained by the given DAG
99 // `tree` returning the op's variable name.
100 void emitOpMatch(DagNode tree, StringRef opName, int depth);
102 // Emits C++ statements for matching the `argIndex`-th argument of the given
103 // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
104 // bound name and the constraint of the operand respectively.
105 void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
106 int operandIndex, DagLeaf operandMatcher,
107 StringRef argName, int argIndex,
108 std::optional<int> variadicSubIndex);
110 // Emits C++ statements for matching the operands which can be matched in
111 // either order.
112 void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
113 StringRef opName, int argIndex, int &operandIndex,
114 int depth);
116 // Emits C++ statements for matching a variadic operand.
117 void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree,
118 StringRef opName, int argIndex,
119 int &operandIndex, int depth);
121 // Emits C++ statements for matching the `argIndex`-th argument of the given
122 // DAG `tree` as an attribute.
123 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
124 int depth);
126 // Emits C++ for checking a match with a corresponding match failure
127 // diagnostic.
128 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
129 const llvm::formatv_object_base &failureFmt);
131 // Emits C++ for checking a match with a corresponding match failure
132 // diagnostics.
133 void emitMatchCheck(StringRef opName, const std::string &matchStr,
134 const std::string &failureStr);
136 //===--------------------------------------------------------------------===//
137 // Rewrite utilities
138 //===--------------------------------------------------------------------===//
140 // The entry point for handling a result pattern rooted at `resultTree`. This
141 // method dispatches to concrete handlers according to `resultTree`'s kind and
142 // returns a symbol representing the whole value pack. Callers are expected to
143 // further resolve the symbol according to the specific use case.
145 // `depth` is the nesting level of `resultTree`; 0 means top-level result
146 // pattern. For top-level result pattern, `resultIndex` indicates which result
147 // of the matched root op this pattern is intended to replace, which can be
148 // used to deduce the result type of the op generated from this result
149 // pattern.
150 std::string handleResultPattern(DagNode resultTree, int resultIndex,
151 int depth);
153 // Emits the C++ statement to replace the matched DAG with a value built via
154 // calling native C++ code.
155 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
157 // Returns the symbol of the old value serving as the replacement.
158 StringRef handleReplaceWithValue(DagNode tree);
160 // Trailing directives are used at the end of DAG node argument lists to
161 // specify additional behaviour for op matchers and creators, etc.
162 struct TrailingDirectives {
163 // DAG node containing the `location` directive. Null if there is none.
164 DagNode location;
166 // DAG node containing the `returnType` directive. Null if there is none.
167 DagNode returnType;
169 // Number of found trailing directives.
170 int numDirectives;
173 // Collect any trailing directives.
174 TrailingDirectives getTrailingDirectives(DagNode tree);
176 // Returns the location value to use.
177 std::string getLocation(TrailingDirectives &tail);
179 // Returns the location value to use.
180 std::string handleLocationDirective(DagNode tree);
182 // Emit return type argument.
183 std::string handleReturnTypeArg(DagNode returnType, int i, int depth);
185 // Emits the C++ statement to build a new op out of the given DAG `tree` and
186 // returns the variable name that this op is assigned to. If the root op in
187 // DAG `tree` has a specified name, the created op will be assigned to a
188 // variable of the given name. Otherwise, a unique name will be used as the
189 // result value name.
190 std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
192 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
194 // Emits a local variable for each value and attribute to be used for creating
195 // an op.
196 void createSeparateLocalVarsForOpArgs(DagNode node,
197 ChildNodeIndexNameMap &childNodeNames);
199 // Emits the concrete arguments used to call an op's builder.
200 void supplyValuesForOpArgs(DagNode node,
201 const ChildNodeIndexNameMap &childNodeNames,
202 int depth);
204 // Emits the local variables for holding all values as a whole and all named
205 // attributes as a whole to be used for creating an op.
206 void createAggregateLocalVarsForOpArgs(
207 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
209 // Returns the C++ expression to construct a constant attribute of the given
210 // `value` for the given attribute kind `attr`.
211 std::string handleConstantAttr(Attribute attr, const Twine &value);
213 // Returns the C++ expression to build an argument from the given DAG `leaf`.
214 // `patArgName` is used to bind the argument to the source pattern.
215 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
217 //===--------------------------------------------------------------------===//
218 // General utilities
219 //===--------------------------------------------------------------------===//
221 // Collects all of the operations within the given dag tree.
222 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
224 // Returns a unique symbol for a local variable of the given `op`.
225 std::string getUniqueSymbol(const Operator *op);
227 //===--------------------------------------------------------------------===//
228 // Symbol utilities
229 //===--------------------------------------------------------------------===//
231 // Returns how many static values the given DAG `node` correspond to.
232 int getNodeValueCount(DagNode node);
234 private:
235 // Pattern instantiation location followed by the location of multiclass
236 // prototypes used. This is intended to be used as a whole to
237 // PrintFatalError() on errors.
238 ArrayRef<SMLoc> loc;
240 // Op's TableGen Record to wrapper object.
241 RecordOperatorMap *opMap;
243 // Handy wrapper for pattern being emitted.
244 Pattern pattern;
246 // Map for all bound symbols' info.
247 SymbolInfoMap symbolInfoMap;
// Helper shared across patterns; queried for the names of static DAG matcher
// functions and static type/attribute verifier functions.
249 StaticMatcherHelper &staticMatcherHelper;
251 // The next unused ID for newly created values.
252 unsigned nextValueId = 0;
// Indentation-aware stream that all generated C++ is written to; it wraps the
// raw_ostream handed to the constructor.
254 raw_indented_ostream os;
256 // Format contexts containing placeholder substitutions.
257 FmtContext fmtCtx;
260 // Tracks DagNodes that are referenced multiple times across patterns. This
261 // enables generating static matcher functions for such DagNodes rather than
262 // inlining their match logic at every use.
263 class StaticMatcherHelper {
264 public:
265 StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper,
266 RecordOperatorMap &mapper);
268 // Determine if we should inline the match logic or delegate to a static
269 // function.
270 bool useStaticMatcher(DagNode node) {
271 // either/variadic node must be associated to the parentOp, thus we can't
272 // emit a static matcher rooted at them.
273 if (node.isEither() || node.isVariadic())
274 return false;
// Note: DenseMap::operator[] inserts a zero count for a node that was never
// recorded, so unseen nodes compare below the threshold.
276 return refStats[node] > kStaticMatcherThreshold;
279 // Get the name of the static DAG matcher function corresponding to the node.
// The result may be empty when no name has been recorded for `node` yet;
// callers in this file test `.empty()` to detect that case.
280 std::string getMatcherName(DagNode node) {
281 assert(useStaticMatcher(node));
282 return matcherNames[node];
285 // Get the name of static type/attribute verification function.
286 StringRef getVerifierName(DagLeaf leaf);
288 // Collect the `Record`s, i.e., the DRR, so that we can get the information of
289 // the duplicated DAGs.
290 void addPattern(Record *record);
292 // Emit all static functions of DAG Matcher.
293 void populateStaticMatchers(raw_ostream &os);
295 // Emit all static functions for Constraints.
296 void populateStaticConstraintFunctions(raw_ostream &os);
298 private:
// A DagNode gets a static matcher once its reference count exceeds this value.
299 static constexpr unsigned kStaticMatcherThreshold = 1;
301 // Consider two patterns as down below,
302 // DagNode_Root_A DagNode_Root_B
303 // \ \
304 // DagNode_C DagNode_C
305 // \ \
306 // DagNode_D DagNode_D
308 // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
309 // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
310 // multiple times so we'll have static matchers for both of them. When we're
311 // emitting the match logic for DagNode_C, we will check if DagNode_D has the
312 // static matcher generated. If so, then we'll generate a call to the
313 // function, inline otherwise. In this case, inlining is not what we want. As
314 // a result, generate the static matcher in topological order to ensure all
315 // the dependent static matchers are generated and we can avoid accidentally
316 // inlining.
318 // The topological order of all the DagNodes among all patterns.
319 SmallVector<std::pair<DagNode, Record *>> topologicalOrder;
// Op's TableGen Record to wrapper object map, supplied by the constructor.
321 RecordOperatorMap &opMap;
323 // Records of the static function name of each DagNode
324 DenseMap<DagNode, std::string> matcherNames;
326 // After collecting all the DagNode in each pattern, `refStats` records the
327 // number of users for each DagNode. We will generate the static matcher for a
328 // DagNode while the number of users exceeds a certain threshold.
329 DenseMap<DagNode, unsigned> refStats;
331 // Number of static matcher generated. This is used to generate a unique name
332 // for each DagNode.
333 int staticMatcherCounter = 0;
335 // The DagLeaf which contains type or attr constraint.
336 SetVector<DagLeaf> constraints;
338 // Static type/attribute verification function emitter.
339 StaticVerifierFunctionEmitter staticVerifierEmitter;
342 } // namespace
344 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
345 raw_ostream &os, StaticMatcherHelper &helper)
346 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
347 symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) {
348 fmtCtx.withBuilder("rewriter");
351 std::string PatternEmitter::handleConstantAttr(Attribute attr,
352 const Twine &value) {
353 if (!attr.isConstBuildable())
354 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
355 " does not have the 'constBuilderCall' field");
357 // TODO: Verify the constants here
358 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
361 void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
362 os << formatv(
363 "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
364 "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
365 "*, 4> &tblgen_ops",
366 funcName);
368 // We pass the reference of the variables that need to be captured. Hence we
369 // need to collect all the symbols in the tree first.
370 pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true);
371 symbolInfoMap.assignUniqueAlternativeNames();
372 for (const auto &info : symbolInfoMap)
373 os << formatv(", {0}", info.second.getArgDecl(info.first));
375 os << ") {\n";
376 os.indent();
377 os << "(void)tblgen_ops;\n";
379 // Note that a static matcher is considered at least one step from the match
380 // entry.
381 emitMatch(tree, "op0", /*depth=*/1);
383 os << "return ::mlir::success();\n";
384 os.unindent();
385 os << "}\n\n";
388 // Helper function to match patterns.
389 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
390 if (tree.isNativeCodeCall()) {
391 emitNativeCodeMatch(tree, name, depth);
392 return;
395 if (tree.isOperation()) {
396 emitOpMatch(tree, name, depth);
397 return;
400 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
403 void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
404 std::string funcName = staticMatcherHelper.getMatcherName(tree);
405 os << formatv("if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", funcName,
406 opName);
408 // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
409 // one pass.
411 // In general, bound symbol should have the unique name in the pattern but
412 // for the operand, binding same symbol to multiple operands imply a
413 // constraint at the same time. In this case, we will rename those operands
414 // with different names. As a result, we need to collect all the symbolInfos
415 // from the DagNode then get the updated name of the local variables from the
416 // global symbolInfoMap.
418 // Collect all the bound symbols in the Dag
419 SymbolInfoMap localSymbolMap(loc);
420 pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);
422 for (const auto &info : localSymbolMap) {
423 auto name = info.first;
424 auto symboInfo = info.second;
425 auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo);
426 os << formatv(", {0}", ret->second.getVarName(name));
429 os << "))) {\n";
430 os.scope().os << "return ::mlir::failure();\n";
431 os << "}\n";
434 void PatternEmitter::emitStaticVerifierCall(StringRef funcName,
435 StringRef opName, StringRef arg,
436 StringRef failureStr) {
437 os << formatv("if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n",
438 funcName, opName, arg, failureStr);
439 os.scope().os << "return ::mlir::failure();\n";
440 os << "}\n";
443 // Emits the match logic for a NativeCodeCall DAG node `tree`. `opName` is
// the name of the operation variable substituted for `$_self` in the native
// code template; `depth` is used to build unique local variable names.
444 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
445 int depth) {
446 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
447 LLVM_DEBUG(tree.print(llvm::dbgs()));
448 LLVM_DEBUG(llvm::dbgs() << '\n');
450 // The order of generating static matcher follows the topological order so
451 // that for every dependent DagNode already have their static matcher
452 // generated if needed. The reason we check if `getMatcherName(tree).empty()`
453 // is when we are generating the static matcher for a DagNode itself. In this
454 // case, we need to emit the function body rather than a function call.
455 if (staticMatcherHelper.useStaticMatcher(tree) &&
456 !staticMatcherHelper.getMatcherName(tree).empty()) {
457 emitStaticMatchCall(tree, opName);
459 // NativeCodeCall will never be at depth 0 so that we don't need to catch
460 // the root operation as emitOpMatch();
462 return;
465 // TODO(suderman): iterate through arguments, determine their types, output
466 // names.
// Names of the locals declared below, one per DAG argument, in argument order.
467 SmallVector<std::string, 8> capture;
469 raw_indented_ostream::DelimitedScope scope(os);
// Declare one local per DAG argument to receive the value produced by the
// native call: attribute leaves become ::mlir::Attribute, everything else
// (nested DAGs and other leaves) becomes ::mlir::Value.
471 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
472 std::string argName = formatv("arg{0}_{1}", depth, i);
473 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
474 if (argTree.isEither())
475 PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
476 if (argTree.isVariadic())
477 PrintFatalError(loc, "NativeCodeCall cannot have `variadic` operands");
479 os << "::mlir::Value " << argName << ";\n";
480 } else {
481 auto leaf = tree.getArgAsLeaf(i);
482 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
483 os << "::mlir::Attribute " << argName << ";\n";
484 } else {
485 os << "::mlir::Value " << argName << ";\n";
489 capture.push_back(std::move(argName));
// A return type directive is rejected for NativeCodeCall matchers; a location
// directive, if any, is resolved through getLocation().
492 auto tail = getTrailingDirectives(tree);
493 if (tail.returnType)
494 PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
495 auto locToUse = getLocation(tail);
// The template must mention `$_self` exactly once.
497 auto fmt = tree.getNativeCodeTemplate();
498 if (fmt.count("$_self") != 1)
499 PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
500 "passing the defining Operation");
// Note: this updates the `_loc` substitution and the self value on the
// shared `fmtCtx` member, not on a temporary copy.
502 auto nativeCodeCall = std::string(
503 tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()),
504 static_cast<ArrayRef<std::string>>(capture)));
506 emitMatchCheck(opName, formatv("!::mlir::failed({0})", nativeCodeCall),
507 formatv("\"{0} return ::mlir::failure\"", nativeCodeCall));
// Bind every named argument (other than `_`) to its captured local. Trailing
// directives are excluded from this and the following loop.
509 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
510 auto name = tree.getArgName(i);
511 if (!name.empty() && name != "_") {
512 os << formatv("{0} = {1};\n", name, capture[i]);
// Verify the constraint attached to each leaf argument through its static
// verifier function.
516 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
517 std::string argName = capture[i];
519 // Handle nested DAG construct first
520 if (tree.getArgAsNestedDag(i)) {
521 PrintFatalError(
522 loc, formatv("Matching nested tree in NativeCodecall not support for "
523 "{0} as arg {1}",
524 argName, i));
527 DagLeaf leaf = tree.getArgAsLeaf(i);
529 // The parameter for native function doesn't bind any constraints.
530 if (leaf.isUnspecified())
531 continue;
533 auto constraint = leaf.getAsConstraint();
// Attributes are verified directly; values are verified through their type.
535 std::string self;
536 if (leaf.isAttrMatcher() || leaf.isConstantAttr())
537 self = argName;
538 else
539 self = formatv("{0}.getType()", argName);
540 StringRef verifier = staticMatcherHelper.getVerifierName(leaf);
541 emitStaticVerifierCall(
542 verifier, opName, self,
543 formatv("\"operand {0} of native code call '{1}' failed to satisfy "
544 "constraint: "
545 "'{2}'\"",
546 i, tree.getNativeCodeTemplate(),
547 escapeString(constraint.getSummary()))
548 .str());
551 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
554 // Emits the match logic for the operation DAG node `tree`. `opName` is the
// name of the variable holding the operation to match and `depth` is the
// nesting level of `tree` (0 for the root of the source pattern).
555 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
556 Operator &op = tree.getDialectOp(opMap);
557 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
558 << op.getOperationName() << "' at depth " << depth
559 << '\n');
// Name of the local holding the op cast to its concrete class at this depth.
561 auto getCastedName = [depth]() -> std::string {
562 return formatv("castedOp{0}", depth);
565 // The order of generating static matcher follows the topological order so
566 // that for every dependent DagNode already have their static matcher
567 // generated if needed. The reason we check if `getMatcherName(tree).empty()`
568 // is when we are generating the static matcher for a DagNode itself. In this
569 // case, we need to emit the function body rather than a function call.
570 if (staticMatcherHelper.useStaticMatcher(tree) &&
571 !staticMatcherHelper.getMatcherName(tree).empty()) {
572 emitStaticMatchCall(tree, opName);
573 // In the codegen of rewriter, we suppose that castedOp0 will capture the
574 // root operation. Manually add it if the root DagNode is a static matcher.
575 if (depth == 0)
576 os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
577 "(void){2};\n",
578 opName, op.getQualCppClassName(), getCastedName());
579 return;
582 std::string castedName = getCastedName();
583 os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
584 "(void){0};\n",
585 castedName, opName, op.getQualCppClassName());
587 // Skip the operand matching at depth 0 as the pattern rewriter already does.
588 if (depth != 0)
589 emitMatchCheck(opName, /*matchStr=*/castedName,
590 formatv("\"{0} is not {1} type\"", castedName,
591 op.getQualCppClassName()));
593 // If the operand's name is set, set to that variable.
594 auto name = tree.getSymbol();
595 if (!name.empty())
596 os << formatv("{0} = {1};\n", name, castedName);
// Three counters advance through the arguments:
//   `i`           - index into the DAG node's arguments,
//   `opArgIdx`    - index into the op's declared arguments,
//   `nextOperand` - index passed to getODSOperands (attributes not counted).
598 for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e;
599 ++i, ++opArgIdx) {
600 auto opArg = op.getArg(opArgIdx);
601 std::string argName = formatv("op{0}", depth + 1);
603 // Handle nested DAG construct first
604 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
605 if (argTree.isEither()) {
606 emitEitherOperandMatch(tree, argTree, castedName, opArgIdx, nextOperand,
607 depth);
// One `either` DAG argument covers two op arguments, so skip an extra one in
// addition to the loop increment.
608 ++opArgIdx;
609 continue;
611 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
612 if (argTree.isVariadic()) {
613 if (!operand->isVariadic()) {
614 auto error = formatv("variadic DAG construct can't match op {0}'s "
615 "non-variadic operand #{1}",
616 op.getOperationName(), opArgIdx);
617 PrintFatalError(loc, error);
619 emitVariadicOperandMatch(tree, argTree, castedName, opArgIdx,
620 nextOperand, depth);
621 ++nextOperand;
622 continue;
624 if (operand->isVariableLength()) {
625 auto error = formatv("use nested DAG construct to match op {0}'s "
626 "variadic operand #{1} unsupported now",
627 op.getOperationName(), opArgIdx);
628 PrintFatalError(loc, error);
// Plain nested op: open a scope in the generated code, fetch the defining op
// of the operand, check it is non-null, and recurse one level deeper.
632 os << "{\n";
634 // Attributes don't count for getODSOperands.
635 // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
636 os.indent() << formatv(
637 "auto *{0} = "
638 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
639 argName, castedName, nextOperand);
640 // Null check of operand's definingOp
641 emitMatchCheck(
642 castedName, /*matchStr=*/argName,
643 formatv("\"There's no operation that defines operand {0} of {1}\"",
644 nextOperand++, castedName));
645 emitMatch(argTree, argName, depth + 1);
646 os << formatv("tblgen_ops.push_back({0});\n", argName);
647 os.unindent() << "}\n";
648 continue;
651 // Next handle DAG leaf: operand or attribute
652 if (opArg.is<NamedTypeConstraint *>()) {
653 auto operandName =
654 formatv("{0}.getODSOperands({1})", castedName, nextOperand);
655 emitOperandMatch(tree, castedName, operandName.str(), opArgIdx,
656 /*operandMatcher=*/tree.getArgAsLeaf(i),
657 /*argName=*/tree.getArgName(i), opArgIdx,
658 /*variadicSubIndex=*/std::nullopt);
659 ++nextOperand;
660 } else if (opArg.is<NamedAttribute *>()) {
661 emitAttributeMatch(tree, opName, opArgIdx, depth);
662 } else {
663 PrintFatalError(loc, "unhandled case when matching op");
666 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
667 << op.getOperationName() << "' at depth " << depth
668 << '\n');
671 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
672 StringRef operandName, int operandIndex,
673 DagLeaf operandMatcher, StringRef argName,
674 int argIndex,
675 std::optional<int> variadicSubIndex) {
676 Operator &op = tree.getDialectOp(opMap);
677 auto *operand = op.getArg(operandIndex).get<NamedTypeConstraint *>();
679 // If a constraint is specified, we need to generate C++ statements to
680 // check the constraint.
681 if (!operandMatcher.isUnspecified()) {
682 if (!operandMatcher.isOperandMatcher())
683 PrintFatalError(
684 loc, formatv("the {1}-th argument of op '{0}' should be an operand",
685 op.getOperationName(), argIndex + 1));
687 // Only need to verify if the matcher's type is different from the one
688 // of op definition.
689 Constraint constraint = operandMatcher.getAsConstraint();
690 if (operand->constraint != constraint) {
691 if (operand->isVariableLength()) {
692 auto error = formatv(
693 "further constrain op {0}'s variadic operand #{1} unsupported now",
694 op.getOperationName(), argIndex);
695 PrintFatalError(loc, error);
697 auto self = formatv("(*{0}.begin()).getType()", operandName);
698 StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher);
699 emitStaticVerifierCall(
700 verifier, opName, self.str(),
701 formatv(
702 "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
703 operand - op.operand_begin(), op.getOperationName(),
704 escapeString(constraint.getSummary()))
705 .str());
709 // Capture the value
710 // `$_` is a special symbol to ignore op argument matching.
711 if (!argName.empty() && argName != "_") {
712 auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex,
713 variadicSubIndex);
714 if (res == symbolInfoMap.end())
715 PrintFatalError(loc, formatv("symbol not found: {0}", argName));
717 os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
// Emits the match logic for an `either` DAG node `eitherArgTree`, whose two
// operands may appear in either order on the op named `opName`. A lambda that
// matches the operands in one fixed order is emitted and then invoked with
// both orderings; the match fails only if both invocations fail.
// `operandIndex` is advanced once per matched operand.
721 void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
722 StringRef opName, int argIndex,
723 int &operandIndex, int depth) {
724 constexpr int numEitherArgs = 2;
725 if (eitherArgTree.getNumArgs() != numEitherArgs)
726 PrintFatalError(loc, "`either` only supports grouping two operands");
728 Operator &op = tree.getDialectOp(opMap);
// Buffers the `tblgen_ops.push_back(...)` statements so they can be emitted
// after all sub-matches inside the lambda.
730 std::string codeBuffer;
731 llvm::raw_string_ostream tblgenOps(codeBuffer);
733 std::string lambda = formatv("eitherLambda{0}", depth);
734 os << formatv(
735 "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n",
736 lambda);
738 os.indent();
740 for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
741 if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) {
742 if (argTree.isEither())
743 PrintFatalError(loc, "either cannot be nested");
745 std::string argName = formatv("local_op_{0}", i).str();
// NOTE(review): the formatv call below ends with a trailing comma in this
// view; its final argument (presumably the index `i` for `v{1}`) is not
// visible here - confirm against the full file.
747 os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName,
750 // Indent emitMatchCheck and emitMatch because they declare local
751 // variables.
752 os << "{\n";
753 os.indent();
755 emitMatchCheck(
756 opName, /*matchStr=*/argName,
757 formatv("\"There's no operation that defines operand {0} of {1}\"",
758 operandIndex++, opName));
759 emitMatch(argTree, argName, depth + 1);
761 os.unindent() << "}\n";
763 // `tblgen_ops` is used to collect the matched operations. In either, we
764 // need to queue the operation only if the matching success. Thus we emit
765 // the code at the end.
766 tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
767 } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
768 emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
769 operandIndex,
770 /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
771 /*argName=*/eitherArgTree.getArgName(i), argIndex,
772 /*variadicSubIndex=*/std::nullopt);
773 ++operandIndex;
774 } else {
775 PrintFatalError(loc, "either can only be applied on operand");
779 os << tblgenOps.str();
780 os << "return ::mlir::success();\n";
781 os.unindent() << "};\n";
783 os << "{\n";
784 os.indent();
// `operandIndex` has been advanced past both operands by the loop above, so
// the pair lives at indices operandIndex - 2 and operandIndex - 1.
786 os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName,
787 operandIndex - 2);
788 os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName,
789 operandIndex - 1);
// Try the operands in declared order first, then swapped.
791 os << formatv("if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && "
792 "::mlir::failed({0}(eitherOperand1, "
793 "eitherOperand0)))\n",
794 lambda);
795 os.indent() << "return ::mlir::failure();\n";
797 os.unindent().unindent() << "}\n";
800 void PatternEmitter::emitVariadicOperandMatch(DagNode tree,
801 DagNode variadicArgTree,
802 StringRef opName, int argIndex,
803 int &operandIndex, int depth) {
804 Operator &op = tree.getDialectOp(opMap);
806 os << "{\n";
807 os.indent();
809 os << formatv("auto variadic_operand_range = {0}.getODSOperands({1});\n",
810 opName, operandIndex);
811 os << formatv("if (variadic_operand_range.size() != {0}) "
812 "return ::mlir::failure();\n",
813 variadicArgTree.getNumArgs());
815 StringRef variadicTreeName = variadicArgTree.getSymbol();
816 if (!variadicTreeName.empty()) {
817 auto res =
818 symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex,
819 /*variadicSubIndex=*/std::nullopt);
820 if (res == symbolInfoMap.end())
821 PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName));
823 os << formatv("{0} = variadic_operand_range;\n",
824 res->second.getVarName(variadicTreeName));
827 for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) {
828 if (DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) {
829 if (!argTree.isOperation())
830 PrintFatalError(loc, "variadic only accepts operation sub-dags");
832 os << "{\n";
833 os.indent();
835 std::string argName = formatv("local_op_{0}", i).str();
836 os << formatv("auto *{0} = "
837 "variadic_operand_range[{1}].getDefiningOp();\n",
838 argName, i);
839 emitMatchCheck(
840 opName, /*matchStr=*/argName,
841 formatv("\"There's no operation that defines variadic operand "
842 "{0} (variadic sub-opearnd #{1}) of {2}\"",
843 operandIndex, i, opName));
844 emitMatch(argTree, argName, depth + 1);
845 os << formatv("tblgen_ops.push_back({0});\n", argName);
847 os.unindent() << "}\n";
848 } else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
849 auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i);
850 emitOperandMatch(tree, opName, operandName.str(), operandIndex,
851 /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i),
852 /*argName=*/variadicArgTree.getArgName(i), argIndex, i);
853 } else {
854 PrintFatalError(loc, "variadic can only be applied on operand");
858 os.unindent() << "}\n";
861 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
862 int argIndex, int depth) {
863 Operator &op = tree.getDialectOp(opMap);
864 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
865 const auto &attr = namedAttr->attr;
867 os << "{\n";
868 os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
869 "(void)tblgen_attr;\n",
870 opName, attr.getStorageType(), namedAttr->name);
872 // TODO: This should use getter method to avoid duplication.
873 if (attr.hasDefaultValue()) {
874 os << "if (!tblgen_attr) tblgen_attr = "
875 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
876 attr.getDefaultValue()))
877 << ";\n";
878 } else if (attr.isOptional()) {
879 // For a missing attribute that is optional according to definition, we
880 // should just capture a mlir::Attribute() to signal the missing state.
881 // That is precisely what getDiscardableAttr() returns on missing
882 // attributes.
883 } else {
884 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
885 formatv("\"expected op '{0}' to have attribute '{1}' "
886 "of type '{2}'\"",
887 op.getOperationName(), namedAttr->name,
888 attr.getStorageType()));
891 auto matcher = tree.getArgAsLeaf(argIndex);
892 if (!matcher.isUnspecified()) {
893 if (!matcher.isAttrMatcher()) {
894 PrintFatalError(
895 loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
896 op.getOperationName(), argIndex + 1));
899 // If a constraint is specified, we need to generate function call to its
900 // static verifier.
901 StringRef verifier = staticMatcherHelper.getVerifierName(matcher);
902 if (attr.isOptional()) {
903 // Avoid dereferencing null attribute. This is using a simple heuristic to
904 // avoid common cases of attempting to dereference null attribute. This
905 // will return where there is no check if attribute is null unless the
906 // attribute's value is not used.
907 // FIXME: This could be improved as some null dereferences could slip
908 // through.
909 if (!StringRef(matcher.getConditionTemplate()).contains("!$_self") &&
910 StringRef(matcher.getConditionTemplate()).contains("$_self")) {
911 os << "if (!tblgen_attr) return ::mlir::failure();\n";
914 emitStaticVerifierCall(
915 verifier, opName, "tblgen_attr",
916 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
917 "'{2}'\"",
918 op.getOperationName(), namedAttr->name,
919 escapeString(matcher.getAsConstraint().getSummary()))
920 .str());
923 // Capture the value
924 auto name = tree.getArgName(argIndex);
925 // `$_` is a special symbol to ignore op argument matching.
926 if (!name.empty() && name != "_") {
927 os << formatv("{0} = tblgen_attr;\n", name);
930 os.unindent() << "}\n";
933 void PatternEmitter::emitMatchCheck(
934 StringRef opName, const FmtObjectBase &matchFmt,
935 const llvm::formatv_object_base &failureFmt) {
936 emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
939 void PatternEmitter::emitMatchCheck(StringRef opName,
940 const std::string &matchStr,
941 const std::string &failureStr) {
943 os << "if (!(" << matchStr << "))";
944 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
945 << ", [&](::mlir::Diagnostic &diag) {\n diag << "
946 << failureStr << ";\n});";
949 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
950 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
951 int depth = 0;
952 emitMatch(tree, opName, depth);
954 for (auto &appliedConstraint : pattern.getConstraints()) {
955 auto &constraint = appliedConstraint.constraint;
956 auto &entities = appliedConstraint.entities;
958 auto condition = constraint.getConditionTemplate();
959 if (isa<TypeConstraint>(constraint)) {
960 if (entities.size() != 1)
961 PrintFatalError(loc, "type constraint requires exactly one argument");
963 auto self = formatv("({0}.getType())",
964 symbolInfoMap.getValueAndRangeUse(entities.front()));
965 emitMatchCheck(
966 opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
967 formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"",
968 entities.front(), escapeString(constraint.getSummary())));
970 } else if (isa<AttrConstraint>(constraint)) {
971 PrintFatalError(
972 loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
973 } else {
974 // TODO: replace formatv arguments with the exact specified
975 // args.
976 if (entities.size() > 4) {
977 PrintFatalError(loc, "only support up to 4-entity constraints now");
979 SmallVector<std::string, 4> names;
980 int i = 0;
981 for (int e = entities.size(); i < e; ++i)
982 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
983 std::string self = appliedConstraint.self;
984 if (!self.empty())
985 self = symbolInfoMap.getValueAndRangeUse(self);
986 for (; i < 4; ++i)
987 names.push_back("<unused>");
988 emitMatchCheck(opName,
989 tgfmt(condition, &fmtCtx.withSelf(self), names[0],
990 names[1], names[2], names[3]),
991 formatv("\"entities '{0}' failed to satisfy constraint: "
992 "'{1}'\"",
993 llvm::join(entities, ", "),
994 escapeString(constraint.getSummary())));
998 // Some of the operands could be bound to the same symbol name, we need
999 // to enforce equality constraint on those.
1000 // TODO: we should be able to emit equality checks early
1001 // and short circuit unnecessary work if vars are not equal.
1002 for (auto symbolInfoIt = symbolInfoMap.begin();
1003 symbolInfoIt != symbolInfoMap.end();) {
1004 auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
1005 auto startRange = range.first;
1006 auto endRange = range.second;
1008 auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
1009 for (++startRange; startRange != endRange; ++startRange) {
1010 auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
1011 emitMatchCheck(
1012 opName,
1013 formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
1014 formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
1015 secondOperand));
1018 symbolInfoIt = endRange;
1021 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
1024 void PatternEmitter::collectOps(DagNode tree,
1025 llvm::SmallPtrSetImpl<const Operator *> &ops) {
1026 // Check if this tree is an operation.
1027 if (tree.isOperation()) {
1028 const Operator &op = tree.getDialectOp(opMap);
1029 LLVM_DEBUG(llvm::dbgs()
1030 << "found operation " << op.getOperationName() << '\n');
1031 ops.insert(&op);
1034 // Recurse the arguments of the tree.
1035 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
1036 if (auto child = tree.getArgAsNestedDag(i))
1037 collectOps(child, ops);
// Emits the complete C++ definition of one rewrite pattern to `os`: a struct
// named `rewriteName` deriving from ::mlir::RewritePattern, with a
// constructor and a matchAndRewrite() override whose body consists of the
// capture-variable declarations, the match logic and the rewrite logic.
1040 void PatternEmitter::emit(StringRef rewriteName) {
1041 // Get the DAG tree for the source pattern.
1042 DagNode sourceTree = pattern.getSourcePattern();
1044 const Operator &rootOp = pattern.getSourceRootOp();
1045 auto rootName = rootOp.getOperationName();
1047 // Collect the set of result operations.
1048 llvm::SmallPtrSet<const Operator *, 4> resultOps;
1049 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
1050 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
1051 collectOps(pattern.getResultPattern(i), resultOps);
1053 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
1055 // Emit RewritePattern for Pattern.
// The leading comment lists the pattern's locations in reverse order.
1056 auto locs = pattern.getLocation();
1057 os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
1058 make_range(locs.rbegin(), locs.rend()));
// Struct header and start of the constructor: root op name, benefit, context
// and the opening brace of the list of generated op names.
1059 os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
1060 {0}(::mlir::MLIRContext *context)
1061 : ::mlir::RewritePattern("{1}", {2}, context, {{)",
1062 rewriteName, rootName, pattern.getBenefit());
1063 // Sort result operators by name.
// NOTE(review): presumably sorted so the emitted list does not depend on the
// iteration order of the pointer set -- confirm.
1064 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
1065 resultOps.end());
1066 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
1067 return lhs->getOperationName() < rhs->getOperationName();
// Emit the quoted, comma-separated op names, then close the list and the
// (empty) constructor body.
1069 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
1070 os << '"' << op->getOperationName() << '"';
1072 os << "}) {}\n";
1074 // Emit matchAndRewrite() function.
// `classScope` and `functionScope` are scope objects obtained from
// os.scope(); the output written while each is alive belongs to the struct
// body and the function body respectively.
1076 auto classScope = os.scope();
1077 os.printReindented(R"(
1078 ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
1079 ::mlir::PatternRewriter &rewriter) const override {)")
1080 << '\n';
1082 auto functionScope = os.scope();
1084 // Register all symbols bound in the source pattern.
1085 pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
1087 LLVM_DEBUG(llvm::dbgs()
1088 << "start creating local variables for capturing matches\n");
1089 os << "// Variables for capturing values and attributes used while "
1090 "creating ops\n";
1091 // Create local variables for storing the arguments and results bound
1092 // to symbols.
1093 for (const auto &symbolInfoPair : symbolInfoMap) {
1094 const auto &symbol = symbolInfoPair.first;
1095 const auto &info = symbolInfoPair.second;
1097 os << info.getVarDecl(symbol);
1099 // TODO: capture ops with consistent numbering so that it can be
1100 // reused for fused loc.
1101 os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
1102 LLVM_DEBUG(llvm::dbgs()
1103 << "done creating local variables for capturing matches\n");
// The root op is always the first entry of `tblgen_ops`; the match is rooted
// at the variable `op0` declared in the matchAndRewrite() signature above.
1105 os << "// Match\n";
1106 os << "tblgen_ops.push_back(op0);\n";
1107 emitMatchLogic(sourceTree, "op0");
1109 os << "\n// Rewrite\n";
1110 emitRewriteLogic();
1112 os << "return ::mlir::success();\n";
// Close matchAndRewrite(), then the struct.
1114 os << "};\n";
1116 os << "};\n\n";
1119 void PatternEmitter::emitRewriteLogic() {
1120 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
1121 const Operator &rootOp = pattern.getSourceRootOp();
1122 int numExpectedResults = rootOp.getNumResults();
1123 int numResultPatterns = pattern.getNumResultPatterns();
1125 // First register all symbols bound to ops generated in result patterns.
1126 pattern.collectResultPatternBoundSymbols(symbolInfoMap);
1128 // Only the last N static values generated are used to replace the matched
1129 // root N-result op. We need to calculate the starting index (of the results
1130 // of the matched op) each result pattern is to replace.
1131 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
1132 // If we don't need to replace any value at all, set the replacement starting
1133 // index as the number of result patterns so we skip all of them when trying
1134 // to replace the matched op's results.
1135 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
1136 for (int i = numResultPatterns - 1; i >= 0; --i) {
1137 auto numValues = getNodeValueCount(pattern.getResultPattern(i));
1138 offsets[i] = offsets[i + 1] - numValues;
1139 if (offsets[i] == 0) {
1140 if (replStartIndex == -1)
1141 replStartIndex = i;
1142 } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
1143 auto error = formatv(
1144 "cannot use the same multi-result op '{0}' to generate both "
1145 "auxiliary values and values to be used for replacing the matched op",
1146 pattern.getResultPattern(i).getSymbol());
1147 PrintFatalError(loc, error);
1151 if (offsets.front() > 0) {
1152 const char error[] =
1153 "not enough values generated to replace the matched op";
1154 PrintFatalError(loc, error);
1157 os << "auto odsLoc = rewriter.getFusedLoc({";
1158 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
1159 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
1161 os << "}); (void)odsLoc;\n";
1163 // Process auxiliary result patterns.
1164 for (int i = 0; i < replStartIndex; ++i) {
1165 DagNode resultTree = pattern.getResultPattern(i);
1166 auto val = handleResultPattern(resultTree, offsets[i], 0);
1167 // Normal op creation will be streamed to `os` by the above call; but
1168 // NativeCodeCall will only be materialized to `os` if it is used. Here
1169 // we are handling auxiliary patterns so we want the side effect even if
1170 // NativeCodeCall is not replacing matched root op's results.
1171 if (resultTree.isNativeCodeCall() &&
1172 resultTree.getNumReturnsOfNativeCode() == 0)
1173 os << val << ";\n";
1176 auto processSupplementalPatterns = [&]() {
1177 int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
1178 for (int i = 0, offset = -numSupplementalPatterns;
1179 i < numSupplementalPatterns; ++i) {
1180 DagNode resultTree = pattern.getSupplementalPattern(i);
1181 auto val = handleResultPattern(resultTree, offset++, 0);
1182 if (resultTree.isNativeCodeCall() &&
1183 resultTree.getNumReturnsOfNativeCode() == 0)
1184 os << val << ";\n";
1188 if (numExpectedResults == 0) {
1189 assert(replStartIndex >= numResultPatterns &&
1190 "invalid auxiliary vs. replacement pattern division!");
1191 processSupplementalPatterns();
1192 // No result to replace. Just erase the op.
1193 os << "rewriter.eraseOp(op0);\n";
1194 } else {
1195 // Process replacement result patterns.
1196 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
1197 for (int i = replStartIndex; i < numResultPatterns; ++i) {
1198 DagNode resultTree = pattern.getResultPattern(i);
1199 auto val = handleResultPattern(resultTree, offsets[i], 0);
1200 os << "\n";
1201 // Resolve each symbol for all range use so that we can loop over them.
1202 // We need an explicit cast to `SmallVector` to capture the cases where
1203 // `{0}` resolves to an `Operation::result_range` as well as cases that
1204 // are not iterable (e.g. vector that gets wrapped in additional braces by
1205 // RewriterGen).
1206 // TODO: Revisit the need for materializing a vector.
1207 os << symbolInfoMap.getAllRangeUse(
1208 val,
1209 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
1210 " tblgen_repl_values.push_back(v);\n}\n",
1211 "\n");
1213 processSupplementalPatterns();
1214 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
1217 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
1220 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
1221 return std::string(
1222 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
1225 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
1226 int resultIndex, int depth) {
1227 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
1228 LLVM_DEBUG(resultTree.print(llvm::dbgs()));
1229 LLVM_DEBUG(llvm::dbgs() << '\n');
1231 if (resultTree.isLocationDirective()) {
1232 PrintFatalError(loc,
1233 "location directive can only be used with op creation");
1236 if (resultTree.isNativeCodeCall())
1237 return handleReplaceWithNativeCodeCall(resultTree, depth);
1239 if (resultTree.isReplaceWithValue())
1240 return handleReplaceWithValue(resultTree).str();
1242 // Normal op creation.
1243 auto symbol = handleOpCreation(resultTree, resultIndex, depth);
1244 if (resultTree.getSymbol().empty()) {
1245 // This is an op not explicitly bound to a symbol in the rewrite rule.
1246 // Register the auto-generated symbol for it.
1247 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
1249 return symbol;
1252 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
1253 assert(tree.isReplaceWithValue());
1255 if (tree.getNumArgs() != 1) {
1256 PrintFatalError(
1257 loc, "replaceWithValue directive must take exactly one argument");
1260 if (!tree.getSymbol().empty()) {
1261 PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
1264 return tree.getArgName(0);
1267 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
1268 assert(tree.isLocationDirective());
1269 auto lookUpArgLoc = [this, &tree](int idx) {
1270 const auto *const lookupFmt = "{0}.getLoc()";
1271 return symbolInfoMap.getValueAndRangeUse(tree.getArgName(idx), lookupFmt);
1274 if (tree.getNumArgs() == 0)
1275 llvm::PrintFatalError(
1276 "At least one argument to location directive required");
1278 if (!tree.getSymbol().empty())
1279 PrintFatalError(loc, "cannot bind symbol to location");
1281 if (tree.getNumArgs() == 1) {
1282 DagLeaf leaf = tree.getArgAsLeaf(0);
1283 if (leaf.isStringAttr())
1284 return formatv("::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))",
1285 leaf.getStringAttr())
1286 .str();
1287 return lookUpArgLoc(0);
1290 std::string ret;
1291 llvm::raw_string_ostream os(ret);
1292 std::string strAttr;
1293 os << "rewriter.getFusedLoc({";
1294 bool first = true;
1295 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1296 DagLeaf leaf = tree.getArgAsLeaf(i);
1297 // Handle the optional string value.
1298 if (leaf.isStringAttr()) {
1299 if (!strAttr.empty())
1300 llvm::PrintFatalError("Only one string attribute may be specified");
1301 strAttr = leaf.getStringAttr();
1302 continue;
1304 os << (first ? "" : ", ") << lookUpArgLoc(i);
1305 first = false;
1307 os << "}";
1308 if (!strAttr.empty()) {
1309 os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
1311 os << ")";
1312 return os.str();
1315 std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
1316 int depth) {
1317 // Nested NativeCodeCall.
1318 if (auto dagNode = returnType.getArgAsNestedDag(i)) {
1319 if (!dagNode.isNativeCodeCall())
1320 PrintFatalError(loc, "nested DAG in `returnType` must be a native code "
1321 "call");
1322 return handleReplaceWithNativeCodeCall(dagNode, depth);
1324 // String literal.
1325 auto dagLeaf = returnType.getArgAsLeaf(i);
1326 if (dagLeaf.isStringAttr())
1327 return tgfmt(dagLeaf.getStringAttr(), &fmtCtx);
1328 return tgfmt(
1329 "$0.getType()", &fmtCtx,
1330 handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i)));
1333 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
1334 StringRef patArgName) {
1335 if (leaf.isStringAttr())
1336 PrintFatalError(loc, "raw string not supported as argument");
1337 if (leaf.isConstantAttr()) {
1338 auto constAttr = leaf.getAsConstantAttr();
1339 return handleConstantAttr(constAttr.getAttribute(),
1340 constAttr.getConstantValue());
1342 if (leaf.isEnumAttrCase()) {
1343 auto enumCase = leaf.getAsEnumAttrCase();
1344 // This is an enum case backed by an IntegerAttr. We need to get its value
1345 // to build the constant.
1346 std::string val = std::to_string(enumCase.getValue());
1347 return handleConstantAttr(enumCase, val);
1350 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
1351 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
1352 if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
1353 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
1354 << "' (via symbol ref)\n");
1355 return argName;
1357 if (leaf.isNativeCodeCall()) {
1358 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
1359 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
1360 << "' (via NativeCodeCall)\n");
1361 return std::string(repl);
1363 PrintFatalError(loc, "unhandled case when rewriting op");
1366 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
1367 int depth) {
1368 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
1369 LLVM_DEBUG(tree.print(llvm::dbgs()));
1370 LLVM_DEBUG(llvm::dbgs() << '\n');
1372 auto fmt = tree.getNativeCodeTemplate();
1374 SmallVector<std::string, 16> attrs;
1376 auto tail = getTrailingDirectives(tree);
1377 if (tail.returnType)
1378 PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
1379 auto locToUse = getLocation(tail);
1381 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1382 if (tree.isNestedDagArg(i)) {
1383 attrs.push_back(
1384 handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
1385 } else {
1386 attrs.push_back(
1387 handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
1389 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
1390 << " replacement: " << attrs[i] << "\n");
1393 std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse),
1394 static_cast<ArrayRef<std::string>>(attrs));
1396 // In general, NativeCodeCall without naming binding don't need this. To
1397 // ensure void helper function has been correctly labeled, i.e., use
1398 // NativeCodeCallVoid, we cache the result to a local variable so that we will
1399 // get a compilation error in the auto-generated file.
1400 // Example.
1401 // // In the td file
1402 // Pat<(...), (NativeCodeCall<Foo> ...)>
1404 // ---
1406 // // In the auto-generated .cpp
1407 // ...
1408 // // Causes compilation error if Foo() returns void.
1409 // auto nativeVar = Foo();
1410 // ...
1411 if (tree.getNumReturnsOfNativeCode() != 0) {
1412 // Determine the local variable name for return value.
1413 std::string varName =
1414 SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
1415 if (varName.empty()) {
1416 varName = formatv("nativeVar_{0}", nextValueId++);
1417 // Register the local variable for later uses.
1418 symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
1421 // Catch the return value of helper function.
1422 os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
1424 if (!tree.getSymbol().empty())
1425 symbol = tree.getSymbol().str();
1426 else
1427 symbol = varName;
1430 return symbol;
1433 int PatternEmitter::getNodeValueCount(DagNode node) {
1434 if (node.isOperation()) {
1435 // If the op is bound to a symbol in the rewrite rule, query its result
1436 // count from the symbol info map.
1437 auto symbol = node.getSymbol();
1438 if (!symbol.empty()) {
1439 return symbolInfoMap.getStaticValueCount(symbol);
1441 // Otherwise this is an unbound op; we will use all its results.
1442 return pattern.getDialectOp(node).getNumResults();
1445 if (node.isNativeCodeCall())
1446 return node.getNumReturnsOfNativeCode();
1448 return 1;
1451 PatternEmitter::TrailingDirectives
1452 PatternEmitter::getTrailingDirectives(DagNode tree) {
1453 TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};
1455 // Look backwards through the arguments.
1456 auto numPatArgs = tree.getNumArgs();
1457 for (int i = numPatArgs - 1; i >= 0; --i) {
1458 auto dagArg = tree.getArgAsNestedDag(i);
1459 // A leaf is not a directive. Stop looking.
1460 if (!dagArg)
1461 break;
1463 auto isLocation = dagArg.isLocationDirective();
1464 auto isReturnType = dagArg.isReturnTypeDirective();
1465 // If encountered a DAG node that isn't a trailing directive, stop looking.
1466 if (!(isLocation || isReturnType))
1467 break;
1468 // Save the directive, but error if one of the same type was already
1469 // found.
1470 ++tail.numDirectives;
1471 if (isLocation) {
1472 if (tail.location)
1473 PrintFatalError(loc, "`location` directive can only be specified "
1474 "once");
1475 tail.location = dagArg;
1476 } else if (isReturnType) {
1477 if (tail.returnType)
1478 PrintFatalError(loc, "`returnType` directive can only be specified "
1479 "once");
1480 tail.returnType = dagArg;
1484 return tail;
1487 std::string
1488 PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
1489 if (tail.location)
1490 return handleLocationDirective(tail.location);
1492 // If no explicit location is given, use the default, all fused, location.
1493 return "odsLoc";
// Emits the code that creates the op described by the result pattern node
// `tree` and returns the symbol that refers to its result.
//
// Nested DAG arguments are handled first (recursively). A local variable for
// the op is then declared and one of three builder forms is emitted:
//  1. aggregate builder without result types, when no `returnType` directive
//     is given and the op has the SameOperandsAndResultType or
//     FirstAttrDerivedResultType trait;
//  2. separate-parameter builder without result types, when no `returnType`
//     directive is given and only part of the results is used, the op is
//     nested (depth > 0), or `resultIndex` is negative;
//  3. aggregate builder with explicit result types, taken either from the
//     `returnType` directive or from the results of `castedOp0`.
//
// `resultIndex` is the index of the first result of the matched root op this
// op replaces; `depth` is the nesting depth inside the result pattern.
1496 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
1497 int depth) {
1498 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
1499 LLVM_DEBUG(tree.print(llvm::dbgs()));
1500 LLVM_DEBUG(llvm::dbgs() << '\n');
1502 Operator &resultOp = tree.getDialectOp(opMap);
1503 auto numOpArgs = resultOp.getNumArgs();
1504 auto numPatArgs = tree.getNumArgs();
1506 auto tail = getTrailingDirectives(tree);
1507 auto locToUse = getLocation(tail);
// Trailing directives are not op arguments; exclude them from the count.
1509 auto inPattern = numPatArgs - tail.numDirectives;
1510 if (numOpArgs != inPattern) {
1511 PrintFatalError(loc,
1512 formatv("resultant op '{0}' argument number mismatch: "
1513 "{1} in pattern vs. {2} in definition",
1514 resultOp.getOperationName(), inPattern, numOpArgs));
1517 // A map to collect all nested DAG child nodes' names, with operand index as
1518 // the key. This includes both bound and unbound child nodes.
1519 ChildNodeIndexNameMap childNodeNames;
1521 // First go through all the child nodes who are nested DAG constructs to
1522 // create ops for them and remember the symbol names for them, so that we can
1523 // use the results in the current node. This happens in a recursive manner.
1524 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1525 if (auto child = tree.getArgAsNestedDag(i))
1526 childNodeNames[i] = handleResultPattern(child, i, depth + 1);
1529 // The name of the local variable holding this op.
1530 std::string valuePackName;
1531 // The symbol for holding the result of this pattern. Note that the result of
1532 // this pattern is not necessarily the same as the variable created by this
1533 // pattern because we can use `__N` suffix to refer only a specific result if
1534 // the generated op is a multi-result op.
1535 std::string resultValue;
1536 if (tree.getSymbol().empty()) {
1537 // No symbol is explicitly bound to this op in the pattern. Generate a
1538 // unique name.
1539 valuePackName = resultValue = getUniqueSymbol(&resultOp);
1540 } else {
1541 resultValue = std::string(tree.getSymbol());
1542 // Strip the index to get the name for the value pack and use it to name the
1543 // local variable for the op.
1544 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
1547 // Create the local variable for this op.
// The brace opened here is closed by whichever of the three paths below
// emits the builder call.
1548 os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
1549 valuePackName);
1551 // Right now ODS don't have general type inference support. Except a few
1552 // special cases listed below, DRR needs to supply types for all results
1553 // when building an op.
1554 bool isSameOperandsAndResultType =
1555 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
1556 bool useFirstAttr =
1557 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
1559 if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
1560 // We know how to deduce the result type for ops with these traits and we've
1561 // generated builders taking aggregate parameters. Use those builders to
1562 // create the ops.
1564 // First prepare local variables for op arguments used in builder call.
1565 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1567 // Then create the op.
// The temporary scope object lives until the end of this statement; its
// closing delimiter "\n}\n" follows the builder call.
1568 os.scope("", "\n}\n").os << formatv(
1569 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1570 valuePackName, resultOp.getQualCppClassName(), locToUse);
1571 return resultValue;
// The names differ when the bound symbol refers to only part of the results
// (see the note on `resultValue` above).
1574 bool usePartialResults = valuePackName != resultValue;
1576 if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
1577 // For these cases (broadcastable ops, op results used both as auxiliary
1578 // values and replacement values, ops in nested patterns, auxiliary ops), we
1579 // still need to supply the result types when building the op. But because
1580 // we don't generate a builder automatically with ODS for them, it's the
1581 // developer's responsibility to make sure such a builder (with result type
1582 // deduction ability) exists. We go through the separate-parameter builder
1583 // here given that it's easier for developers to write compared to
1584 // aggregate-parameter builders.
1585 createSeparateLocalVarsForOpArgs(tree, childNodeNames);
// The temporary scope object only spans this statement; the remaining
// arguments and the closing text are written by the statements that follow.
1587 os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
1588 resultOp.getQualCppClassName(), locToUse);
1589 supplyValuesForOpArgs(tree, childNodeNames, depth);
1590 os << "\n );\n}\n";
1591 return resultValue;
1594 // If we are provided explicit return types, use them to build the op.
1595 // However, if depth == 0 and resultIndex >= 0, it means we are replacing
1596 // the values generated from the source pattern root op. Then we must use the
1597 // source pattern's value types to determine the value type of the generated
1598 // op here.
1599 if (depth == 0 && resultIndex >= 0 && tail.returnType)
1600 PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
1601 "return values replace the source pattern's root op");
1603 // First prepare local variables for op arguments used in builder call.
1604 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1606 // Then prepare the result types. We need to specify the types for all
1607 // results.
1608 os.indent() << formatv("::llvm::SmallVector<::mlir::Type, 4> tblgen_types; "
1609 "(void)tblgen_types;\n");
1610 int numResults = resultOp.getNumResults();
1611 if (tail.returnType) {
// One result type per argument of the `returnType` directive.
1612 auto numRetTys = tail.returnType.getNumArgs();
1613 for (int i = 0; i < numRetTys; ++i) {
1614 auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
1615 os << "tblgen_types.push_back(" << varName << ");\n";
1617 } else {
1618 if (numResults != 0) {
1619 // Copy the result types from the source pattern.
1620 for (int i = 0; i < numResults; ++i)
1621 os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
1622 " tblgen_types.push_back(v.getType());\n}\n",
1623 resultIndex + i);
1626 os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
1627 "tblgen_values, tblgen_attrs);\n",
1628 valuePackName, resultOp.getQualCppClassName(), locToUse);
1629 os.unindent() << "}\n";
1630 return resultValue;
1633 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1634 DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1635 Operator &resultOp = node.getDialectOp(opMap);
1637 // Now prepare operands used for building this op:
1638 // * If the operand is non-variadic, we create a `Value` local variable.
1639 // * If the operand is variadic, we create a `SmallVector<Value>` local
1640 // variable.
1642 int valueIndex = 0; // An index for uniquing local variable names.
1643 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1644 const auto *operand =
1645 llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(argIndex));
1646 // We do not need special handling for attributes.
1647 if (!operand)
1648 continue;
1650 raw_indented_ostream::DelimitedScope scope(os);
1651 std::string varName;
1652 if (operand->isVariadic()) {
1653 varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1654 os << formatv("::llvm::SmallVector<::mlir::Value, 4> {0};\n", varName);
1655 std::string range;
1656 if (node.isNestedDagArg(argIndex)) {
1657 range = childNodeNames[argIndex];
1658 } else {
1659 range = std::string(node.getArgName(argIndex));
1661 // Resolve the symbol for all range use so that we have a uniform way of
1662 // capturing the values.
1663 range = symbolInfoMap.getValueAndRangeUse(range);
1664 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range,
1665 varName);
1666 } else {
1667 varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1668 os << formatv("::mlir::Value {0} = ", varName);
1669 if (node.isNestedDagArg(argIndex)) {
1670 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1671 } else {
1672 DagLeaf leaf = node.getArgAsLeaf(argIndex);
1673 auto symbol =
1674 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1675 if (leaf.isNativeCodeCall()) {
1676 os << std::string(
1677 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1678 } else {
1679 os << symbol;
1682 os << ";\n";
1685 // Update to use the newly created local variable for building the op later.
1686 childNodeNames[argIndex] = varName;
1690 void PatternEmitter::supplyValuesForOpArgs(
1691 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1692 Operator &resultOp = node.getDialectOp(opMap);
1693 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1694 argIndex != numOpArgs; ++argIndex) {
1695 // Start each argument on its own line.
1696 os << ",\n ";
1698 Argument opArg = resultOp.getArg(argIndex);
1699 // Handle the case of operand first.
1700 if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) {
1701 if (!operand->name.empty())
1702 os << "/*" << operand->name << "=*/";
1703 os << childNodeNames.lookup(argIndex);
1704 continue;
1707 // The argument in the op definition.
1708 auto opArgName = resultOp.getArgName(argIndex);
1709 if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1710 if (!subTree.isNativeCodeCall())
1711 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1712 "for creating attribute");
1713 os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
1714 } else {
1715 auto leaf = node.getArgAsLeaf(argIndex);
1716 // The argument in the result DAG pattern.
1717 auto patArgName = node.getArgName(argIndex);
1718 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1719 // TODO: Refactor out into map to avoid recomputing these.
1720 if (!opArg.is<NamedAttribute *>())
1721 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1722 if (!patArgName.empty())
1723 os << "/*" << patArgName << "=*/";
1724 } else {
1725 os << "/*" << opArgName << "=*/";
1727 os << handleOpArgument(leaf, patArgName);
1732 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1733 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1734 Operator &resultOp = node.getDialectOp(opMap);
1736 auto scope = os.scope();
1737 os << formatv("::llvm::SmallVector<::mlir::Value, 4> "
1738 "tblgen_values; (void)tblgen_values;\n");
1739 os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> "
1740 "tblgen_attrs; (void)tblgen_attrs;\n");
1742 const char *addAttrCmd =
1743 "if (auto tmpAttr = {1}) {\n"
1744 " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), "
1745 "tmpAttr);\n}\n";
1746 int numVariadic = 0;
1747 bool hasOperandSegmentSizes = false;
1748 std::vector<std::string> sizes;
1749 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1750 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1751 // The argument in the op definition.
1752 auto opArgName = resultOp.getArgName(argIndex);
1753 hasOperandSegmentSizes =
1754 hasOperandSegmentSizes || opArgName == "operandSegmentSizes";
1755 if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1756 if (!subTree.isNativeCodeCall())
1757 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1758 "for creating attribute");
1759 os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
1760 } else {
1761 auto leaf = node.getArgAsLeaf(argIndex);
1762 // The argument in the result DAG pattern.
1763 auto patArgName = node.getArgName(argIndex);
1764 os << formatv(addAttrCmd, opArgName,
1765 handleOpArgument(leaf, patArgName));
1767 continue;
1770 const auto *operand =
1771 resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1772 std::string varName;
1773 if (operand->isVariadic()) {
1774 ++numVariadic;
1775 std::string range;
1776 if (node.isNestedDagArg(argIndex)) {
1777 range = childNodeNames.lookup(argIndex);
1778 } else {
1779 range = std::string(node.getArgName(argIndex));
1781 // Resolve the symbol for all range use so that we have a uniform way of
1782 // capturing the values.
1783 range = symbolInfoMap.getValueAndRangeUse(range);
1784 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
1785 range);
1786 sizes.push_back(formatv("static_cast<int32_t>({0}.size())", range));
1787 } else {
1788 sizes.push_back("1");
1789 os << formatv("tblgen_values.push_back(");
1790 if (node.isNestedDagArg(argIndex)) {
1791 os << symbolInfoMap.getValueAndRangeUse(
1792 childNodeNames.lookup(argIndex));
1793 } else {
1794 DagLeaf leaf = node.getArgAsLeaf(argIndex);
1795 if (leaf.isConstantAttr())
1796 // TODO: Use better location
1797 PrintFatalError(
1798 loc,
1799 "attribute found where value was expected, if attempting to use "
1800 "constant value, construct a constant op with given attribute "
1801 "instead");
1803 auto symbol =
1804 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1805 if (leaf.isNativeCodeCall()) {
1806 os << std::string(
1807 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1808 } else {
1809 os << symbol;
1812 os << ");\n";
1816 if (numVariadic > 1 && !hasOperandSegmentSizes) {
1817 // Only set size if it can't be computed.
1818 const auto *sameVariadicSize =
1819 resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
1820 if (!sameVariadicSize) {
1821 const char *setSizes = R"(
1822 tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"),
1823 rewriter.getDenseI32ArrayAttr({{ {0} }));
1825 os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str());
1830 StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
1831 const RecordKeeper &recordKeeper,
1832 RecordOperatorMap &mapper)
1833 : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {}
1835 void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
1836 // PatternEmitter will use the static matcher if there's one generated. To
1837 // ensure that all the dependent static matchers are generated before emitting
1838 // the matching logic of the DagNode, we use topological order to achieve it.
1839 for (auto &dagInfo : topologicalOrder) {
1840 DagNode node = dagInfo.first;
1841 if (!useStaticMatcher(node))
1842 continue;
1844 std::string funcName =
1845 formatv("static_dag_matcher_{0}", staticMatcherCounter++);
1846 assert(!matcherNames.contains(node));
1847 PatternEmitter(dagInfo.second, &opMap, os, *this)
1848 .emitStaticMatcher(node, funcName);
1849 matcherNames[node] = funcName;
1853 void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
1854 staticVerifierEmitter.emitPatternConstraints(constraints.getArrayRef());
1857 void StaticMatcherHelper::addPattern(Record *record) {
1858 Pattern pat(record, &opMap);
1860 // While generating the function body of the DAG matcher, it may depends on
1861 // other DAG matchers. To ensure the dependent matchers are ready, we compute
1862 // the topological order for all the DAGs and emit the DAG matchers in this
1863 // order.
1864 llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
1865 ++refStats[node];
1867 if (refStats[node] != 1)
1868 return;
1870 for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
1871 if (DagNode sibling = node.getArgAsNestedDag(i))
1872 dfs(sibling);
1873 else {
1874 DagLeaf leaf = node.getArgAsLeaf(i);
1875 if (!leaf.isUnspecified())
1876 constraints.insert(leaf);
1879 topologicalOrder.push_back(std::make_pair(node, record));
1882 dfs(pat.getSourcePattern());
1885 StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
1886 if (leaf.isAttrMatcher()) {
1887 std::optional<StringRef> constraint =
1888 staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint());
1889 assert(constraint && "attribute constraint was not uniqued");
1890 return *constraint;
1892 assert(leaf.isOperandMatcher());
1893 return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
1896 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1897 emitSourceFileHeader("Rewriters", os, recordKeeper);
1899 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1901 // We put the map here because it can be shared among multiple patterns.
1902 RecordOperatorMap recordOpMap;
1904 // Exam all the patterns and generate static matcher for the duplicated
1905 // DagNode.
1906 StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap);
1907 for (Record *p : patterns)
1908 staticMatcher.addPattern(p);
1909 staticMatcher.populateStaticConstraintFunctions(os);
1910 staticMatcher.populateStaticMatchers(os);
1912 std::vector<std::string> rewriterNames;
1913 rewriterNames.reserve(patterns.size());
1915 std::string baseRewriterName = "GeneratedConvert";
1916 int rewriterIndex = 0;
1918 for (Record *p : patterns) {
1919 std::string name;
1920 if (p->isAnonymous()) {
1921 // If no name is provided, ensure unique rewriter names simply by
1922 // appending unique suffix.
1923 name = baseRewriterName + llvm::utostr(rewriterIndex++);
1924 } else {
1925 name = std::string(p->getName());
1927 LLVM_DEBUG(llvm::dbgs()
1928 << "=== start generating pattern '" << name << "' ===\n");
1929 PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
1930 LLVM_DEBUG(llvm::dbgs()
1931 << "=== done generating pattern '" << name << "' ===\n");
1932 rewriterNames.push_back(std::move(name));
1935 // Emit function to add the generated matchers to the pattern list.
1936 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
1937 "::mlir::RewritePatternSet &patterns) {\n";
1938 for (const auto &name : rewriterNames) {
1939 os << " patterns.add<" << name << ">(patterns.getContext());\n";
1941 os << "}\n";
1944 static mlir::GenRegistration
1945 genRewriters("gen-rewriters", "Generate pattern rewriters",
1946 [](const RecordKeeper &records, raw_ostream &os) {
1947 emitRewriters(records, os);
1948 return false;