1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
12 //===----------------------------------------------------------------------===//
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
27 using namespace tblgen
;
31 //===----------------------------------------------------------------------===//
33 //===----------------------------------------------------------------------===//
35 bool DagLeaf::isUnspecified() const {
36 return isa_and_nonnull
<llvm::UnsetInit
>(def
);
39 bool DagLeaf::isOperandMatcher() const {
40 // Operand matchers specify a type constraint.
41 return isSubClassOf("TypeConstraint");
44 bool DagLeaf::isAttrMatcher() const {
45 // Attribute matchers specify an attribute constraint.
46 return isSubClassOf("AttrConstraint");
49 bool DagLeaf::isNativeCodeCall() const {
50 return isSubClassOf("NativeCodeCall");
53 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
55 bool DagLeaf::isEnumAttrCase() const {
56 return isSubClassOf("EnumAttrCaseInfo");
59 bool DagLeaf::isStringAttr() const { return isa
<llvm::StringInit
>(def
); }
61 Constraint
DagLeaf::getAsConstraint() const {
62 assert((isOperandMatcher() || isAttrMatcher()) &&
63 "the DAG leaf must be operand or attribute");
64 return Constraint(cast
<llvm::DefInit
>(def
)->getDef());
67 ConstantAttr
DagLeaf::getAsConstantAttr() const {
68 assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69 return ConstantAttr(cast
<llvm::DefInit
>(def
));
72 EnumAttrCase
DagLeaf::getAsEnumAttrCase() const {
73 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74 return EnumAttrCase(cast
<llvm::DefInit
>(def
));
77 std::string
DagLeaf::getConditionTemplate() const {
78 return getAsConstraint().getConditionTemplate();
81 llvm::StringRef
DagLeaf::getNativeCodeTemplate() const {
82 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83 return cast
<llvm::DefInit
>(def
)->getDef()->getValueAsString("expression");
86 int DagLeaf::getNumReturnsOfNativeCode() const {
87 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
88 return cast
<llvm::DefInit
>(def
)->getDef()->getValueAsInt("numReturns");
91 std::string
DagLeaf::getStringAttr() const {
92 assert(isStringAttr() && "the DAG leaf must be string attribute");
93 return def
->getAsUnquotedString();
95 bool DagLeaf::isSubClassOf(StringRef superclass
) const {
96 if (auto *defInit
= dyn_cast_or_null
<llvm::DefInit
>(def
))
97 return defInit
->getDef()->isSubClassOf(superclass
);
101 void DagLeaf::print(raw_ostream
&os
) const {
106 //===----------------------------------------------------------------------===//
108 //===----------------------------------------------------------------------===//
110 bool DagNode::isNativeCodeCall() const {
111 if (auto *defInit
= dyn_cast_or_null
<llvm::DefInit
>(node
->getOperator()))
112 return defInit
->getDef()->isSubClassOf("NativeCodeCall");
116 bool DagNode::isOperation() const {
117 return !isNativeCodeCall() && !isReplaceWithValue() &&
118 !isLocationDirective() && !isReturnTypeDirective() && !isEither() &&
122 llvm::StringRef
DagNode::getNativeCodeTemplate() const {
123 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
124 return cast
<llvm::DefInit
>(node
->getOperator())
126 ->getValueAsString("expression");
129 int DagNode::getNumReturnsOfNativeCode() const {
130 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
131 return cast
<llvm::DefInit
>(node
->getOperator())
133 ->getValueAsInt("numReturns");
136 llvm::StringRef
DagNode::getSymbol() const { return node
->getNameStr(); }
138 Operator
&DagNode::getDialectOp(RecordOperatorMap
*mapper
) const {
139 llvm::Record
*opDef
= cast
<llvm::DefInit
>(node
->getOperator())->getDef();
140 auto it
= mapper
->find(opDef
);
141 if (it
!= mapper
->end())
143 return *mapper
->try_emplace(opDef
, std::make_unique
<Operator
>(opDef
))
147 int DagNode::getNumOps() const {
148 // We want to get number of operations recursively involved in the DAG tree.
149 // All other directives should be excluded.
150 int count
= isOperation() ? 1 : 0;
151 for (int i
= 0, e
= getNumArgs(); i
!= e
; ++i
) {
152 if (auto child
= getArgAsNestedDag(i
))
153 count
+= child
.getNumOps();
158 int DagNode::getNumArgs() const { return node
->getNumArgs(); }
160 bool DagNode::isNestedDagArg(unsigned index
) const {
161 return isa
<llvm::DagInit
>(node
->getArg(index
));
164 DagNode
DagNode::getArgAsNestedDag(unsigned index
) const {
165 return DagNode(dyn_cast_or_null
<llvm::DagInit
>(node
->getArg(index
)));
168 DagLeaf
DagNode::getArgAsLeaf(unsigned index
) const {
169 assert(!isNestedDagArg(index
));
170 return DagLeaf(node
->getArg(index
));
173 StringRef
DagNode::getArgName(unsigned index
) const {
174 return node
->getArgNameStr(index
);
177 bool DagNode::isReplaceWithValue() const {
178 auto *dagOpDef
= cast
<llvm::DefInit
>(node
->getOperator())->getDef();
179 return dagOpDef
->getName() == "replaceWithValue";
182 bool DagNode::isLocationDirective() const {
183 auto *dagOpDef
= cast
<llvm::DefInit
>(node
->getOperator())->getDef();
184 return dagOpDef
->getName() == "location";
187 bool DagNode::isReturnTypeDirective() const {
188 auto *dagOpDef
= cast
<llvm::DefInit
>(node
->getOperator())->getDef();
189 return dagOpDef
->getName() == "returnType";
192 bool DagNode::isEither() const {
193 auto *dagOpDef
= cast
<llvm::DefInit
>(node
->getOperator())->getDef();
194 return dagOpDef
->getName() == "either";
197 bool DagNode::isVariadic() const {
198 auto *dagOpDef
= cast
<llvm::DefInit
>(node
->getOperator())->getDef();
199 return dagOpDef
->getName() == "variadic";
202 void DagNode::print(raw_ostream
&os
) const {
207 //===----------------------------------------------------------------------===//
209 //===----------------------------------------------------------------------===//
211 StringRef
SymbolInfoMap::getValuePackName(StringRef symbol
, int *index
) {
213 auto [name
, indexStr
] = symbol
.rsplit("__");
215 if (indexStr
.consumeInteger(10, idx
)) {
216 // The second part is not an index; we return the whole symbol as-is.
225 SymbolInfoMap::SymbolInfo::SymbolInfo(
226 const Operator
*op
, SymbolInfo::Kind kind
,
227 std::optional
<DagAndConstant
> dagAndConstant
)
228 : op(op
), kind(kind
), dagAndConstant(dagAndConstant
) {}
230 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
237 return op
->getNumResults();
238 case Kind::MultipleValues
:
241 llvm_unreachable("unknown kind");
244 std::string
SymbolInfoMap::SymbolInfo::getVarName(StringRef name
) const {
245 return alternativeName
? *alternativeName
: name
.str();
248 std::string
SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name
) const {
249 LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name
<< "': ");
253 return op
->getArg(getArgIndex())
254 .get
<NamedAttribute
*>()
255 ->attr
.getStorageType()
257 // TODO(suderman): Use a more exact type when available.
258 return "::mlir::Attribute";
260 case Kind::Operand
: {
261 // Use operand range for captured operands (to support potential variadic
263 return "::mlir::Operation::operand_range";
266 return "::mlir::Value";
268 case Kind::MultipleValues
: {
269 return "::mlir::ValueRange";
272 // Use the op itself for captured results.
273 return op
->getQualCppClassName();
276 llvm_unreachable("unknown kind");
279 std::string
SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name
) const {
280 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name
<< "': ");
281 std::string varInit
= kind
== Kind::Operand
? "(op0->getOperands())" : "";
283 formatv("{0} {1}{2};\n", getVarTypeStr(name
), getVarName(name
), varInit
));
286 std::string
SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name
) const {
287 LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name
<< "': ");
289 formatv("{0} &{1}", getVarTypeStr(name
), getVarName(name
)));
292 std::string
SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
293 StringRef name
, int index
, const char *fmt
, const char *separator
) const {
294 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name
<< "': ");
298 auto repl
= formatv(fmt
, name
);
299 LLVM_DEBUG(llvm::dbgs() << repl
<< " (Attr)\n");
300 return std::string(repl
);
302 case Kind::Operand
: {
304 auto *operand
= op
->getArg(getArgIndex()).get
<NamedTypeConstraint
*>();
305 // If this operand is variadic and this SymbolInfo doesn't have a range
306 // index, then return the full variadic operand_range. Otherwise, return
308 if (operand
->isVariableLength() && !getVariadicSubIndex().has_value()) {
309 auto repl
= formatv(fmt
, name
);
310 LLVM_DEBUG(llvm::dbgs() << repl
<< " (VariadicOperand)\n");
311 return std::string(repl
);
313 auto repl
= formatv(fmt
, formatv("(*{0}.begin())", name
));
314 LLVM_DEBUG(llvm::dbgs() << repl
<< " (SingleOperand)\n");
315 return std::string(repl
);
318 // If `index` is greater than zero, then we are referencing a specific
319 // result of a multi-result op. The result can still be variadic.
322 std::string(formatv("{0}.getODSResults({1})", name
, index
));
323 if (!op
->getResult(index
).isVariadic())
324 v
= std::string(formatv("(*{0}.begin())", v
));
325 auto repl
= formatv(fmt
, v
);
326 LLVM_DEBUG(llvm::dbgs() << repl
<< " (SingleResult)\n");
327 return std::string(repl
);
330 // If this op has no result at all but still we bind a symbol to it, it
331 // means we want to capture the op itself.
332 if (op
->getNumResults() == 0) {
333 LLVM_DEBUG(llvm::dbgs() << name
<< " (Op)\n");
334 return formatv(fmt
, name
);
337 // We are referencing all results of the multi-result op. A specific result
338 // can either be a value or a range. Then join them with `separator`.
339 SmallVector
<std::string
, 4> values
;
340 values
.reserve(op
->getNumResults());
342 for (int i
= 0, e
= op
->getNumResults(); i
< e
; ++i
) {
343 std::string v
= std::string(formatv("{0}.getODSResults({1})", name
, i
));
344 if (!op
->getResult(i
).isVariadic()) {
345 v
= std::string(formatv("(*{0}.begin())", v
));
347 values
.push_back(std::string(formatv(fmt
, v
)));
349 auto repl
= llvm::join(values
, separator
);
350 LLVM_DEBUG(llvm::dbgs() << repl
<< " (VariadicResult)\n");
355 assert(op
== nullptr);
356 auto repl
= formatv(fmt
, name
);
357 LLVM_DEBUG(llvm::dbgs() << repl
<< " (Value)\n");
358 return std::string(repl
);
360 case Kind::MultipleValues
: {
361 assert(op
== nullptr);
362 assert(index
< getSize());
365 formatv(fmt
, std::string(formatv("{0}[{1}]", name
, index
)));
366 LLVM_DEBUG(llvm::dbgs() << repl
<< " (MultipleValues)\n");
369 // If it doesn't specify certain element, unpack them all.
371 formatv(fmt
, std::string(formatv("{0}.begin(), {0}.end()", name
)));
372 LLVM_DEBUG(llvm::dbgs() << repl
<< " (MultipleValues)\n");
373 return std::string(repl
);
376 llvm_unreachable("unknown kind");
379 std::string
SymbolInfoMap::SymbolInfo::getAllRangeUse(
380 StringRef name
, int index
, const char *fmt
, const char *separator
) const {
381 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name
<< "': ");
384 case Kind::Operand
: {
385 assert(index
< 0 && "only allowed for symbol bound to result");
386 auto repl
= formatv(fmt
, name
);
387 LLVM_DEBUG(llvm::dbgs() << repl
<< " (Operand/Attr)\n");
388 return std::string(repl
);
392 auto repl
= formatv(fmt
, formatv("{0}.getODSResults({1})", name
, index
));
393 LLVM_DEBUG(llvm::dbgs() << repl
<< " (SingleResult)\n");
394 return std::string(repl
);
397 // We are referencing all results of the multi-result op. Each result should
398 // have a value range, and then join them with `separator`.
399 SmallVector
<std::string
, 4> values
;
400 values
.reserve(op
->getNumResults());
402 for (int i
= 0, e
= op
->getNumResults(); i
< e
; ++i
) {
403 values
.push_back(std::string(
404 formatv(fmt
, formatv("{0}.getODSResults({1})", name
, i
))));
406 auto repl
= llvm::join(values
, separator
);
407 LLVM_DEBUG(llvm::dbgs() << repl
<< " (VariadicResult)\n");
411 assert(index
< 0 && "only allowed for symbol bound to result");
412 assert(op
== nullptr);
413 auto repl
= formatv(fmt
, formatv("{{{0}}", name
));
414 LLVM_DEBUG(llvm::dbgs() << repl
<< " (Value)\n");
415 return std::string(repl
);
417 case Kind::MultipleValues
: {
418 assert(op
== nullptr);
419 assert(index
< getSize());
422 formatv(fmt
, std::string(formatv("{0}[{1}]", name
, index
)));
423 LLVM_DEBUG(llvm::dbgs() << repl
<< " (MultipleValues)\n");
427 formatv(fmt
, std::string(formatv("{0}.begin(), {0}.end()", name
)));
428 LLVM_DEBUG(llvm::dbgs() << repl
<< " (MultipleValues)\n");
429 return std::string(repl
);
432 llvm_unreachable("unknown kind");
435 bool SymbolInfoMap::bindOpArgument(DagNode node
, StringRef symbol
,
436 const Operator
&op
, int argIndex
,
437 std::optional
<int> variadicSubIndex
) {
438 StringRef name
= getValuePackName(symbol
);
439 if (name
!= symbol
) {
440 auto error
= formatv(
441 "symbol '{0}' with trailing index cannot bind to op argument", symbol
);
442 PrintFatalError(loc
, error
);
446 op
.getArg(argIndex
).is
<NamedAttribute
*>()
447 ? SymbolInfo::getAttr(&op
, argIndex
)
448 : SymbolInfo::getOperand(node
, &op
, argIndex
, variadicSubIndex
);
450 std::string key
= symbol
.str();
451 if (symbolInfoMap
.count(key
)) {
452 // Only non unique name for the operand is supported.
453 if (symInfo
.kind
!= SymbolInfo::Kind::Operand
) {
457 // Cannot add new operand if there is already non operand with the same
459 if (symbolInfoMap
.find(key
)->second
.kind
!= SymbolInfo::Kind::Operand
) {
464 symbolInfoMap
.emplace(key
, symInfo
);
468 bool SymbolInfoMap::bindOpResult(StringRef symbol
, const Operator
&op
) {
469 std::string name
= getValuePackName(symbol
).str();
470 auto inserted
= symbolInfoMap
.emplace(name
, SymbolInfo::getResult(&op
));
472 return symbolInfoMap
.count(inserted
->first
) == 1;
475 bool SymbolInfoMap::bindValues(StringRef symbol
, int numValues
) {
476 std::string name
= getValuePackName(symbol
).str();
478 return bindMultipleValues(name
, numValues
);
479 return bindValue(name
);
482 bool SymbolInfoMap::bindValue(StringRef symbol
) {
483 auto inserted
= symbolInfoMap
.emplace(symbol
.str(), SymbolInfo::getValue());
484 return symbolInfoMap
.count(inserted
->first
) == 1;
487 bool SymbolInfoMap::bindMultipleValues(StringRef symbol
, int numValues
) {
488 std::string name
= getValuePackName(symbol
).str();
490 symbolInfoMap
.emplace(name
, SymbolInfo::getMultipleValues(numValues
));
491 return symbolInfoMap
.count(inserted
->first
) == 1;
494 bool SymbolInfoMap::bindAttr(StringRef symbol
) {
495 auto inserted
= symbolInfoMap
.emplace(symbol
.str(), SymbolInfo::getAttr());
496 return symbolInfoMap
.count(inserted
->first
) == 1;
499 bool SymbolInfoMap::contains(StringRef symbol
) const {
500 return find(symbol
) != symbolInfoMap
.end();
503 SymbolInfoMap::const_iterator
SymbolInfoMap::find(StringRef key
) const {
504 std::string name
= getValuePackName(key
).str();
506 return symbolInfoMap
.find(name
);
509 SymbolInfoMap::const_iterator
510 SymbolInfoMap::findBoundSymbol(StringRef key
, DagNode node
, const Operator
&op
,
512 std::optional
<int> variadicSubIndex
) const {
513 return findBoundSymbol(
514 key
, SymbolInfo::getOperand(node
, &op
, argIndex
, variadicSubIndex
));
517 SymbolInfoMap::const_iterator
518 SymbolInfoMap::findBoundSymbol(StringRef key
,
519 const SymbolInfo
&symbolInfo
) const {
520 std::string name
= getValuePackName(key
).str();
521 auto range
= symbolInfoMap
.equal_range(name
);
523 for (auto it
= range
.first
; it
!= range
.second
; ++it
)
524 if (it
->second
.dagAndConstant
== symbolInfo
.dagAndConstant
)
527 return symbolInfoMap
.end();
530 std::pair
<SymbolInfoMap::iterator
, SymbolInfoMap::iterator
>
531 SymbolInfoMap::getRangeOfEqualElements(StringRef key
) {
532 std::string name
= getValuePackName(key
).str();
534 return symbolInfoMap
.equal_range(name
);
537 int SymbolInfoMap::count(StringRef key
) const {
538 std::string name
= getValuePackName(key
).str();
539 return symbolInfoMap
.count(name
);
542 int SymbolInfoMap::getStaticValueCount(StringRef symbol
) const {
543 StringRef name
= getValuePackName(symbol
);
544 if (name
!= symbol
) {
545 // If there is a trailing index inside symbol, it references just one
549 // Otherwise, find how many it represents by querying the symbol's info.
550 return find(name
)->second
.getStaticValueCount();
553 std::string
SymbolInfoMap::getValueAndRangeUse(StringRef symbol
,
555 const char *separator
) const {
557 StringRef name
= getValuePackName(symbol
, &index
);
559 auto it
= symbolInfoMap
.find(name
.str());
560 if (it
== symbolInfoMap
.end()) {
561 auto error
= formatv("referencing unbound symbol '{0}'", symbol
);
562 PrintFatalError(loc
, error
);
565 return it
->second
.getValueAndRangeUse(name
, index
, fmt
, separator
);
568 std::string
SymbolInfoMap::getAllRangeUse(StringRef symbol
, const char *fmt
,
569 const char *separator
) const {
571 StringRef name
= getValuePackName(symbol
, &index
);
573 auto it
= symbolInfoMap
.find(name
.str());
574 if (it
== symbolInfoMap
.end()) {
575 auto error
= formatv("referencing unbound symbol '{0}'", symbol
);
576 PrintFatalError(loc
, error
);
579 return it
->second
.getAllRangeUse(name
, index
, fmt
, separator
);
582 void SymbolInfoMap::assignUniqueAlternativeNames() {
583 llvm::StringSet
<> usedNames
;
585 for (auto symbolInfoIt
= symbolInfoMap
.begin();
586 symbolInfoIt
!= symbolInfoMap
.end();) {
587 auto range
= symbolInfoMap
.equal_range(symbolInfoIt
->first
);
588 auto startRange
= range
.first
;
589 auto endRange
= range
.second
;
591 auto operandName
= symbolInfoIt
->first
;
592 int startSearchIndex
= 0;
593 for (++startRange
; startRange
!= endRange
; ++startRange
) {
594 // Current operand name is not unique, find a unique one
595 // and set the alternative name.
596 for (int i
= startSearchIndex
;; ++i
) {
597 std::string alternativeName
= operandName
+ std::to_string(i
);
598 if (!usedNames
.contains(alternativeName
) &&
599 symbolInfoMap
.count(alternativeName
) == 0) {
600 usedNames
.insert(alternativeName
);
601 startRange
->second
.alternativeName
= alternativeName
;
602 startSearchIndex
= i
+ 1;
609 symbolInfoIt
= endRange
;
613 //===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
617 Pattern::Pattern(const llvm::Record
*def
, RecordOperatorMap
*mapper
)
618 : def(*def
), recordOpMap(mapper
) {}
620 DagNode
Pattern::getSourcePattern() const {
621 return DagNode(def
.getValueAsDag("sourcePattern"));
624 int Pattern::getNumResultPatterns() const {
625 auto *results
= def
.getValueAsListInit("resultPatterns");
626 return results
->size();
629 DagNode
Pattern::getResultPattern(unsigned index
) const {
630 auto *results
= def
.getValueAsListInit("resultPatterns");
631 return DagNode(cast
<llvm::DagInit
>(results
->getElement(index
)));
634 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap
&infoMap
) {
635 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
636 collectBoundSymbols(getSourcePattern(), infoMap
, /*isSrcPattern=*/true);
637 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
639 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
640 infoMap
.assignUniqueAlternativeNames();
641 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
644 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap
&infoMap
) {
645 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
646 for (int i
= 0, e
= getNumResultPatterns(); i
< e
; ++i
) {
647 auto pattern
= getResultPattern(i
);
648 collectBoundSymbols(pattern
, infoMap
, /*isSrcPattern=*/false);
650 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
653 const Operator
&Pattern::getSourceRootOp() {
654 return getSourcePattern().getDialectOp(recordOpMap
);
657 Operator
&Pattern::getDialectOp(DagNode node
) {
658 return node
.getDialectOp(recordOpMap
);
661 std::vector
<AppliedConstraint
> Pattern::getConstraints() const {
662 auto *listInit
= def
.getValueAsListInit("constraints");
663 std::vector
<AppliedConstraint
> ret
;
664 ret
.reserve(listInit
->size());
666 for (auto *it
: *listInit
) {
667 auto *dagInit
= dyn_cast
<llvm::DagInit
>(it
);
669 PrintFatalError(&def
, "all elements in Pattern multi-entity "
670 "constraints should be DAG nodes");
672 std::vector
<std::string
> entities
;
673 entities
.reserve(dagInit
->arg_size());
674 for (auto *argName
: dagInit
->getArgNames()) {
678 "operands to additional constraints can only be symbol references");
680 entities
.emplace_back(argName
->getValue());
683 ret
.emplace_back(cast
<llvm::DefInit
>(dagInit
->getOperator())->getDef(),
684 dagInit
->getNameStr(), std::move(entities
));
689 int Pattern::getNumSupplementalPatterns() const {
690 auto *results
= def
.getValueAsListInit("supplementalPatterns");
691 return results
->size();
694 DagNode
Pattern::getSupplementalPattern(unsigned index
) const {
695 auto *results
= def
.getValueAsListInit("supplementalPatterns");
696 return DagNode(cast
<llvm::DagInit
>(results
->getElement(index
)));
699 int Pattern::getBenefit() const {
700 // The initial benefit value is a heuristic with number of ops in the source
702 int initBenefit
= getSourcePattern().getNumOps();
703 llvm::DagInit
*delta
= def
.getValueAsDag("benefitDelta");
704 if (delta
->getNumArgs() != 1 || !isa
<llvm::IntInit
>(delta
->getArg(0))) {
705 PrintFatalError(&def
,
706 "The 'addBenefit' takes and only takes one integer value");
708 return initBenefit
+ dyn_cast
<llvm::IntInit
>(delta
->getArg(0))->getValue();
711 std::vector
<Pattern::IdentifierLine
> Pattern::getLocation() const {
712 std::vector
<std::pair
<StringRef
, unsigned>> result
;
713 result
.reserve(def
.getLoc().size());
714 for (auto loc
: def
.getLoc()) {
715 unsigned buf
= llvm::SrcMgr
.FindBufferContainingLoc(loc
);
716 assert(buf
&& "invalid source location");
718 llvm::SrcMgr
.getBufferInfo(buf
).Buffer
->getBufferIdentifier(),
719 llvm::SrcMgr
.getLineAndColumn(loc
, buf
).first
);
724 void Pattern::verifyBind(bool result
, StringRef symbolName
) {
726 auto err
= formatv("symbol '{0}' bound more than once", symbolName
);
727 PrintFatalError(&def
, err
);
731 void Pattern::collectBoundSymbols(DagNode tree
, SymbolInfoMap
&infoMap
,
733 auto treeName
= tree
.getSymbol();
734 auto numTreeArgs
= tree
.getNumArgs();
736 if (tree
.isNativeCodeCall()) {
737 if (!treeName
.empty()) {
739 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
740 << treeName
<< '\n');
742 infoMap
.bindValues(treeName
, tree
.getNumReturnsOfNativeCode()),
745 PrintFatalError(&def
,
746 formatv("binding symbol '{0}' to NativecodeCall in "
747 "MatchPattern is not supported",
752 for (int i
= 0; i
!= numTreeArgs
; ++i
) {
753 if (auto treeArg
= tree
.getArgAsNestedDag(i
)) {
754 // This DAG node argument is a DAG node itself. Go inside recursively.
755 collectBoundSymbols(treeArg
, infoMap
, isSrcPattern
);
762 // We can only bind symbols to arguments in source pattern. Those
763 // symbols are referenced in result patterns.
764 auto treeArgName
= tree
.getArgName(i
);
766 // `$_` is a special symbol meaning ignore the current argument.
767 if (!treeArgName
.empty() && treeArgName
!= "_") {
768 DagLeaf leaf
= tree
.getArgAsLeaf(i
);
770 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
771 if (leaf
.isUnspecified()) {
772 // This is case of $c, a Value without any constraints.
773 verifyBind(infoMap
.bindValue(treeArgName
), treeArgName
);
775 auto constraint
= leaf
.getAsConstraint();
776 bool isAttr
= leaf
.isAttrMatcher() || leaf
.isEnumAttrCase() ||
777 leaf
.isConstantAttr() ||
778 constraint
.getKind() == Constraint::Kind::CK_Attr
;
781 // This is case of $a, a binding to a certain attribute.
782 verifyBind(infoMap
.bindAttr(treeArgName
), treeArgName
);
786 // This is case of $b, a binding to a certain type.
787 verifyBind(infoMap
.bindValue(treeArgName
), treeArgName
);
795 if (tree
.isOperation()) {
796 auto &op
= getDialectOp(tree
);
797 auto numOpArgs
= op
.getNumArgs();
800 // We need to exclude the trailing directives and `either` directive groups
801 // two operands of the operation.
802 int numDirectives
= 0;
803 for (int i
= numTreeArgs
- 1; i
>= 0; --i
) {
804 if (auto dagArg
= tree
.getArgAsNestedDag(i
)) {
805 if (dagArg
.isLocationDirective() || dagArg
.isReturnTypeDirective())
807 else if (dagArg
.isEither())
812 if (numOpArgs
!= numTreeArgs
- numDirectives
+ numEither
) {
814 formatv("op '{0}' argument number mismatch: "
815 "{1} in pattern vs. {2} in definition",
816 op
.getOperationName(), numTreeArgs
+ numEither
, numOpArgs
);
817 PrintFatalError(&def
, err
);
820 // The name attached to the DAG node's operator is for representing the
821 // results generated from this op. It should be remembered as bound results.
822 if (!treeName
.empty()) {
823 LLVM_DEBUG(llvm::dbgs()
824 << "found symbol bound to op result: " << treeName
<< '\n');
825 verifyBind(infoMap
.bindOpResult(treeName
, op
), treeName
);
828 // The operand in `either` DAG should be bound to the operation in the
830 auto collectSymbolInEither
= [&](DagNode parent
, DagNode tree
,
832 for (int i
= 0; i
< tree
.getNumArgs(); ++i
, ++opArgIdx
) {
833 if (DagNode subTree
= tree
.getArgAsNestedDag(i
)) {
834 collectBoundSymbols(subTree
, infoMap
, isSrcPattern
);
836 auto argName
= tree
.getArgName(i
);
837 if (!argName
.empty() && argName
!= "_") {
838 verifyBind(infoMap
.bindOpArgument(parent
, argName
, op
, opArgIdx
),
845 // The operand in `variadic` DAG should be bound to the operation in the
846 // parent DagNode. The range index must be included as well to distinguish
847 // (potentially) repeating argName within the `variadic` DAG.
848 auto collectSymbolInVariadic
= [&](DagNode parent
, DagNode tree
,
850 auto treeName
= tree
.getSymbol();
851 if (!treeName
.empty()) {
852 // If treeName is specified, bind to the full variadic operand_range.
853 verifyBind(infoMap
.bindOpArgument(parent
, treeName
, op
, opArgIdx
,
858 for (int i
= 0; i
< tree
.getNumArgs(); ++i
) {
859 if (DagNode subTree
= tree
.getArgAsNestedDag(i
)) {
860 collectBoundSymbols(subTree
, infoMap
, isSrcPattern
);
862 auto argName
= tree
.getArgName(i
);
863 if (!argName
.empty() && argName
!= "_") {
864 verifyBind(infoMap
.bindOpArgument(parent
, argName
, op
, opArgIdx
,
865 /*variadicSubIndex=*/i
),
872 for (int i
= 0, opArgIdx
= 0; i
!= numTreeArgs
; ++i
, ++opArgIdx
) {
873 if (auto treeArg
= tree
.getArgAsNestedDag(i
)) {
874 if (treeArg
.isEither()) {
875 collectSymbolInEither(tree
, treeArg
, opArgIdx
);
876 // `either` DAG is *flattened*. For example,
878 // (FooOp (either arg0, arg1), arg2)
882 // (FooOp arg0, arg1, arg2)
884 } else if (treeArg
.isVariadic()) {
885 collectSymbolInVariadic(tree
, treeArg
, opArgIdx
);
887 // This DAG node argument is a DAG node itself. Go inside recursively.
888 collectBoundSymbols(treeArg
, infoMap
, isSrcPattern
);
894 // We can only bind symbols to op arguments in source pattern. Those
895 // symbols are referenced in result patterns.
896 auto treeArgName
= tree
.getArgName(i
);
897 // `$_` is a special symbol meaning ignore the current argument.
898 if (!treeArgName
.empty() && treeArgName
!= "_") {
899 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
900 << treeArgName
<< '\n');
901 verifyBind(infoMap
.bindOpArgument(tree
, treeArgName
, op
, opArgIdx
),
909 if (!treeName
.empty()) {
911 &def
, formatv("binding symbol '{0}' to non-operation/native code call "
912 "unsupported right now",