[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / mlir / lib / TableGen / Pattern.cpp
blobafb69e7cc55866a4c74f0b8b7a4f70953d962449
1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR
10 // Pattern.
12 //===----------------------------------------------------------------------===//
14 #include <utility>
16 #include "mlir/TableGen/Pattern.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
24 #define DEBUG_TYPE "mlir-tblgen-pattern"
26 using namespace mlir;
27 using namespace tblgen;
29 using llvm::formatv;
31 //===----------------------------------------------------------------------===//
32 // DagLeaf
33 //===----------------------------------------------------------------------===//
35 bool DagLeaf::isUnspecified() const {
36 return isa_and_nonnull<llvm::UnsetInit>(def);
39 bool DagLeaf::isOperandMatcher() const {
40 // Operand matchers specify a type constraint.
41 return isSubClassOf("TypeConstraint");
44 bool DagLeaf::isAttrMatcher() const {
45 // Attribute matchers specify an attribute constraint.
46 return isSubClassOf("AttrConstraint");
49 bool DagLeaf::isNativeCodeCall() const {
50 return isSubClassOf("NativeCodeCall");
53 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
55 bool DagLeaf::isEnumAttrCase() const {
56 return isSubClassOf("EnumAttrCaseInfo");
59 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
61 Constraint DagLeaf::getAsConstraint() const {
62 assert((isOperandMatcher() || isAttrMatcher()) &&
63 "the DAG leaf must be operand or attribute");
64 return Constraint(cast<llvm::DefInit>(def)->getDef());
67 ConstantAttr DagLeaf::getAsConstantAttr() const {
68 assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69 return ConstantAttr(cast<llvm::DefInit>(def));
72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
73 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74 return EnumAttrCase(cast<llvm::DefInit>(def));
77 std::string DagLeaf::getConditionTemplate() const {
78 return getAsConstraint().getConditionTemplate();
81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
82 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
86 int DagLeaf::getNumReturnsOfNativeCode() const {
87 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
88 return cast<llvm::DefInit>(def)->getDef()->getValueAsInt("numReturns");
91 std::string DagLeaf::getStringAttr() const {
92 assert(isStringAttr() && "the DAG leaf must be string attribute");
93 return def->getAsUnquotedString();
95 bool DagLeaf::isSubClassOf(StringRef superclass) const {
96 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
97 return defInit->getDef()->isSubClassOf(superclass);
98 return false;
101 void DagLeaf::print(raw_ostream &os) const {
102 if (def)
103 def->print(os);
106 //===----------------------------------------------------------------------===//
107 // DagNode
108 //===----------------------------------------------------------------------===//
110 bool DagNode::isNativeCodeCall() const {
111 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
112 return defInit->getDef()->isSubClassOf("NativeCodeCall");
113 return false;
116 bool DagNode::isOperation() const {
117 return !isNativeCodeCall() && !isReplaceWithValue() &&
118 !isLocationDirective() && !isReturnTypeDirective() && !isEither() &&
119 !isVariadic();
122 llvm::StringRef DagNode::getNativeCodeTemplate() const {
123 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
124 return cast<llvm::DefInit>(node->getOperator())
125 ->getDef()
126 ->getValueAsString("expression");
129 int DagNode::getNumReturnsOfNativeCode() const {
130 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
131 return cast<llvm::DefInit>(node->getOperator())
132 ->getDef()
133 ->getValueAsInt("numReturns");
136 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
138 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
139 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
140 auto it = mapper->find(opDef);
141 if (it != mapper->end())
142 return *it->second;
143 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
144 .first->second;
147 int DagNode::getNumOps() const {
148 // We want to get number of operations recursively involved in the DAG tree.
149 // All other directives should be excluded.
150 int count = isOperation() ? 1 : 0;
151 for (int i = 0, e = getNumArgs(); i != e; ++i) {
152 if (auto child = getArgAsNestedDag(i))
153 count += child.getNumOps();
155 return count;
158 int DagNode::getNumArgs() const { return node->getNumArgs(); }
160 bool DagNode::isNestedDagArg(unsigned index) const {
161 return isa<llvm::DagInit>(node->getArg(index));
164 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
165 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
168 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
169 assert(!isNestedDagArg(index));
170 return DagLeaf(node->getArg(index));
173 StringRef DagNode::getArgName(unsigned index) const {
174 return node->getArgNameStr(index);
177 bool DagNode::isReplaceWithValue() const {
178 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
179 return dagOpDef->getName() == "replaceWithValue";
182 bool DagNode::isLocationDirective() const {
183 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
184 return dagOpDef->getName() == "location";
187 bool DagNode::isReturnTypeDirective() const {
188 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
189 return dagOpDef->getName() == "returnType";
192 bool DagNode::isEither() const {
193 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
194 return dagOpDef->getName() == "either";
197 bool DagNode::isVariadic() const {
198 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
199 return dagOpDef->getName() == "variadic";
202 void DagNode::print(raw_ostream &os) const {
203 if (node)
204 node->print(os);
207 //===----------------------------------------------------------------------===//
208 // SymbolInfoMap
209 //===----------------------------------------------------------------------===//
211 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
212 int idx = -1;
213 auto [name, indexStr] = symbol.rsplit("__");
215 if (indexStr.consumeInteger(10, idx)) {
216 // The second part is not an index; we return the whole symbol as-is.
217 return symbol;
219 if (index) {
220 *index = idx;
222 return name;
225 SymbolInfoMap::SymbolInfo::SymbolInfo(
226 const Operator *op, SymbolInfo::Kind kind,
227 std::optional<DagAndConstant> dagAndConstant)
228 : op(op), kind(kind), dagAndConstant(dagAndConstant) {}
230 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
231 switch (kind) {
232 case Kind::Attr:
233 case Kind::Operand:
234 case Kind::Value:
235 return 1;
236 case Kind::Result:
237 return op->getNumResults();
238 case Kind::MultipleValues:
239 return getSize();
241 llvm_unreachable("unknown kind");
244 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
245 return alternativeName ? *alternativeName : name.str();
248 std::string SymbolInfoMap::SymbolInfo::getVarTypeStr(StringRef name) const {
249 LLVM_DEBUG(llvm::dbgs() << "getVarTypeStr for '" << name << "': ");
250 switch (kind) {
251 case Kind::Attr: {
252 if (op)
253 return op->getArg(getArgIndex())
254 .get<NamedAttribute *>()
255 ->attr.getStorageType()
256 .str();
257 // TODO(suderman): Use a more exact type when available.
258 return "::mlir::Attribute";
260 case Kind::Operand: {
261 // Use operand range for captured operands (to support potential variadic
262 // operands).
263 return "::mlir::Operation::operand_range";
265 case Kind::Value: {
266 return "::mlir::Value";
268 case Kind::MultipleValues: {
269 return "::mlir::ValueRange";
271 case Kind::Result: {
272 // Use the op itself for captured results.
273 return op->getQualCppClassName();
276 llvm_unreachable("unknown kind");
279 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
280 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
281 std::string varInit = kind == Kind::Operand ? "(op0->getOperands())" : "";
282 return std::string(
283 formatv("{0} {1}{2};\n", getVarTypeStr(name), getVarName(name), varInit));
286 std::string SymbolInfoMap::SymbolInfo::getArgDecl(StringRef name) const {
287 LLVM_DEBUG(llvm::dbgs() << "getArgDecl for '" << name << "': ");
288 return std::string(
289 formatv("{0} &{1}", getVarTypeStr(name), getVarName(name)));
292 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
293 StringRef name, int index, const char *fmt, const char *separator) const {
294 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
295 switch (kind) {
296 case Kind::Attr: {
297 assert(index < 0);
298 auto repl = formatv(fmt, name);
299 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
300 return std::string(repl);
302 case Kind::Operand: {
303 assert(index < 0);
304 auto *operand = op->getArg(getArgIndex()).get<NamedTypeConstraint *>();
305 // If this operand is variadic and this SymbolInfo doesn't have a range
306 // index, then return the full variadic operand_range. Otherwise, return
307 // the value itself.
308 if (operand->isVariableLength() && !getVariadicSubIndex().has_value()) {
309 auto repl = formatv(fmt, name);
310 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
311 return std::string(repl);
313 auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
314 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
315 return std::string(repl);
317 case Kind::Result: {
318 // If `index` is greater than zero, then we are referencing a specific
319 // result of a multi-result op. The result can still be variadic.
320 if (index >= 0) {
321 std::string v =
322 std::string(formatv("{0}.getODSResults({1})", name, index));
323 if (!op->getResult(index).isVariadic())
324 v = std::string(formatv("(*{0}.begin())", v));
325 auto repl = formatv(fmt, v);
326 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
327 return std::string(repl);
330 // If this op has no result at all but still we bind a symbol to it, it
331 // means we want to capture the op itself.
332 if (op->getNumResults() == 0) {
333 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
334 return formatv(fmt, name);
337 // We are referencing all results of the multi-result op. A specific result
338 // can either be a value or a range. Then join them with `separator`.
339 SmallVector<std::string, 4> values;
340 values.reserve(op->getNumResults());
342 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
343 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
344 if (!op->getResult(i).isVariadic()) {
345 v = std::string(formatv("(*{0}.begin())", v));
347 values.push_back(std::string(formatv(fmt, v)));
349 auto repl = llvm::join(values, separator);
350 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
351 return repl;
353 case Kind::Value: {
354 assert(index < 0);
355 assert(op == nullptr);
356 auto repl = formatv(fmt, name);
357 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
358 return std::string(repl);
360 case Kind::MultipleValues: {
361 assert(op == nullptr);
362 assert(index < getSize());
363 if (index >= 0) {
364 std::string repl =
365 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
366 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
367 return repl;
369 // If it doesn't specify certain element, unpack them all.
370 auto repl =
371 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
372 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
373 return std::string(repl);
376 llvm_unreachable("unknown kind");
379 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
380 StringRef name, int index, const char *fmt, const char *separator) const {
381 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
382 switch (kind) {
383 case Kind::Attr:
384 case Kind::Operand: {
385 assert(index < 0 && "only allowed for symbol bound to result");
386 auto repl = formatv(fmt, name);
387 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
388 return std::string(repl);
390 case Kind::Result: {
391 if (index >= 0) {
392 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
393 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
394 return std::string(repl);
397 // We are referencing all results of the multi-result op. Each result should
398 // have a value range, and then join them with `separator`.
399 SmallVector<std::string, 4> values;
400 values.reserve(op->getNumResults());
402 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
403 values.push_back(std::string(
404 formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
406 auto repl = llvm::join(values, separator);
407 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
408 return repl;
410 case Kind::Value: {
411 assert(index < 0 && "only allowed for symbol bound to result");
412 assert(op == nullptr);
413 auto repl = formatv(fmt, formatv("{{{0}}", name));
414 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
415 return std::string(repl);
417 case Kind::MultipleValues: {
418 assert(op == nullptr);
419 assert(index < getSize());
420 if (index >= 0) {
421 std::string repl =
422 formatv(fmt, std::string(formatv("{0}[{1}]", name, index)));
423 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
424 return repl;
426 auto repl =
427 formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name)));
428 LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n");
429 return std::string(repl);
432 llvm_unreachable("unknown kind");
435 bool SymbolInfoMap::bindOpArgument(DagNode node, StringRef symbol,
436 const Operator &op, int argIndex,
437 std::optional<int> variadicSubIndex) {
438 StringRef name = getValuePackName(symbol);
439 if (name != symbol) {
440 auto error = formatv(
441 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
442 PrintFatalError(loc, error);
445 auto symInfo =
446 op.getArg(argIndex).is<NamedAttribute *>()
447 ? SymbolInfo::getAttr(&op, argIndex)
448 : SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex);
450 std::string key = symbol.str();
451 if (symbolInfoMap.count(key)) {
452 // Only non unique name for the operand is supported.
453 if (symInfo.kind != SymbolInfo::Kind::Operand) {
454 return false;
457 // Cannot add new operand if there is already non operand with the same
458 // name.
459 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
460 return false;
464 symbolInfoMap.emplace(key, symInfo);
465 return true;
468 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
469 std::string name = getValuePackName(symbol).str();
470 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
472 return symbolInfoMap.count(inserted->first) == 1;
475 bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) {
476 std::string name = getValuePackName(symbol).str();
477 if (numValues > 1)
478 return bindMultipleValues(name, numValues);
479 return bindValue(name);
482 bool SymbolInfoMap::bindValue(StringRef symbol) {
483 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
484 return symbolInfoMap.count(inserted->first) == 1;
487 bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) {
488 std::string name = getValuePackName(symbol).str();
489 auto inserted =
490 symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues));
491 return symbolInfoMap.count(inserted->first) == 1;
494 bool SymbolInfoMap::bindAttr(StringRef symbol) {
495 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
496 return symbolInfoMap.count(inserted->first) == 1;
499 bool SymbolInfoMap::contains(StringRef symbol) const {
500 return find(symbol) != symbolInfoMap.end();
503 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
504 std::string name = getValuePackName(key).str();
506 return symbolInfoMap.find(name);
509 SymbolInfoMap::const_iterator
510 SymbolInfoMap::findBoundSymbol(StringRef key, DagNode node, const Operator &op,
511 int argIndex,
512 std::optional<int> variadicSubIndex) const {
513 return findBoundSymbol(
514 key, SymbolInfo::getOperand(node, &op, argIndex, variadicSubIndex));
517 SymbolInfoMap::const_iterator
518 SymbolInfoMap::findBoundSymbol(StringRef key,
519 const SymbolInfo &symbolInfo) const {
520 std::string name = getValuePackName(key).str();
521 auto range = symbolInfoMap.equal_range(name);
523 for (auto it = range.first; it != range.second; ++it)
524 if (it->second.dagAndConstant == symbolInfo.dagAndConstant)
525 return it;
527 return symbolInfoMap.end();
530 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
531 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
532 std::string name = getValuePackName(key).str();
534 return symbolInfoMap.equal_range(name);
537 int SymbolInfoMap::count(StringRef key) const {
538 std::string name = getValuePackName(key).str();
539 return symbolInfoMap.count(name);
542 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
543 StringRef name = getValuePackName(symbol);
544 if (name != symbol) {
545 // If there is a trailing index inside symbol, it references just one
546 // static value.
547 return 1;
549 // Otherwise, find how many it represents by querying the symbol's info.
550 return find(name)->second.getStaticValueCount();
553 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
554 const char *fmt,
555 const char *separator) const {
556 int index = -1;
557 StringRef name = getValuePackName(symbol, &index);
559 auto it = symbolInfoMap.find(name.str());
560 if (it == symbolInfoMap.end()) {
561 auto error = formatv("referencing unbound symbol '{0}'", symbol);
562 PrintFatalError(loc, error);
565 return it->second.getValueAndRangeUse(name, index, fmt, separator);
568 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
569 const char *separator) const {
570 int index = -1;
571 StringRef name = getValuePackName(symbol, &index);
573 auto it = symbolInfoMap.find(name.str());
574 if (it == symbolInfoMap.end()) {
575 auto error = formatv("referencing unbound symbol '{0}'", symbol);
576 PrintFatalError(loc, error);
579 return it->second.getAllRangeUse(name, index, fmt, separator);
582 void SymbolInfoMap::assignUniqueAlternativeNames() {
583 llvm::StringSet<> usedNames;
585 for (auto symbolInfoIt = symbolInfoMap.begin();
586 symbolInfoIt != symbolInfoMap.end();) {
587 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
588 auto startRange = range.first;
589 auto endRange = range.second;
591 auto operandName = symbolInfoIt->first;
592 int startSearchIndex = 0;
593 for (++startRange; startRange != endRange; ++startRange) {
594 // Current operand name is not unique, find a unique one
595 // and set the alternative name.
596 for (int i = startSearchIndex;; ++i) {
597 std::string alternativeName = operandName + std::to_string(i);
598 if (!usedNames.contains(alternativeName) &&
599 symbolInfoMap.count(alternativeName) == 0) {
600 usedNames.insert(alternativeName);
601 startRange->second.alternativeName = alternativeName;
602 startSearchIndex = i + 1;
604 break;
609 symbolInfoIt = endRange;
613 //===----------------------------------------------------------------------===//
614 // Pattern
615 //==----------------------------------------------------------------------===//
617 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
618 : def(*def), recordOpMap(mapper) {}
620 DagNode Pattern::getSourcePattern() const {
621 return DagNode(def.getValueAsDag("sourcePattern"));
624 int Pattern::getNumResultPatterns() const {
625 auto *results = def.getValueAsListInit("resultPatterns");
626 return results->size();
629 DagNode Pattern::getResultPattern(unsigned index) const {
630 auto *results = def.getValueAsListInit("resultPatterns");
631 return DagNode(cast<llvm::DagInit>(results->getElement(index)));
634 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
635 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
636 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
637 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
639 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
640 infoMap.assignUniqueAlternativeNames();
641 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
644 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
645 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
646 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
647 auto pattern = getResultPattern(i);
648 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
650 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
653 const Operator &Pattern::getSourceRootOp() {
654 return getSourcePattern().getDialectOp(recordOpMap);
657 Operator &Pattern::getDialectOp(DagNode node) {
658 return node.getDialectOp(recordOpMap);
661 std::vector<AppliedConstraint> Pattern::getConstraints() const {
662 auto *listInit = def.getValueAsListInit("constraints");
663 std::vector<AppliedConstraint> ret;
664 ret.reserve(listInit->size());
666 for (auto *it : *listInit) {
667 auto *dagInit = dyn_cast<llvm::DagInit>(it);
668 if (!dagInit)
669 PrintFatalError(&def, "all elements in Pattern multi-entity "
670 "constraints should be DAG nodes");
672 std::vector<std::string> entities;
673 entities.reserve(dagInit->arg_size());
674 for (auto *argName : dagInit->getArgNames()) {
675 if (!argName) {
676 PrintFatalError(
677 &def,
678 "operands to additional constraints can only be symbol references");
680 entities.emplace_back(argName->getValue());
683 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
684 dagInit->getNameStr(), std::move(entities));
686 return ret;
689 int Pattern::getNumSupplementalPatterns() const {
690 auto *results = def.getValueAsListInit("supplementalPatterns");
691 return results->size();
694 DagNode Pattern::getSupplementalPattern(unsigned index) const {
695 auto *results = def.getValueAsListInit("supplementalPatterns");
696 return DagNode(cast<llvm::DagInit>(results->getElement(index)));
699 int Pattern::getBenefit() const {
700 // The initial benefit value is a heuristic with number of ops in the source
701 // pattern.
702 int initBenefit = getSourcePattern().getNumOps();
703 llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
704 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
705 PrintFatalError(&def,
706 "The 'addBenefit' takes and only takes one integer value");
708 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
711 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
712 std::vector<std::pair<StringRef, unsigned>> result;
713 result.reserve(def.getLoc().size());
714 for (auto loc : def.getLoc()) {
715 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
716 assert(buf && "invalid source location");
717 result.emplace_back(
718 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
719 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
721 return result;
724 void Pattern::verifyBind(bool result, StringRef symbolName) {
725 if (!result) {
726 auto err = formatv("symbol '{0}' bound more than once", symbolName);
727 PrintFatalError(&def, err);
731 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
732 bool isSrcPattern) {
733 auto treeName = tree.getSymbol();
734 auto numTreeArgs = tree.getNumArgs();
736 if (tree.isNativeCodeCall()) {
737 if (!treeName.empty()) {
738 if (!isSrcPattern) {
739 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: "
740 << treeName << '\n');
741 verifyBind(
742 infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()),
743 treeName);
744 } else {
745 PrintFatalError(&def,
746 formatv("binding symbol '{0}' to NativecodeCall in "
747 "MatchPattern is not supported",
748 treeName));
752 for (int i = 0; i != numTreeArgs; ++i) {
753 if (auto treeArg = tree.getArgAsNestedDag(i)) {
754 // This DAG node argument is a DAG node itself. Go inside recursively.
755 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
756 continue;
759 if (!isSrcPattern)
760 continue;
762 // We can only bind symbols to arguments in source pattern. Those
763 // symbols are referenced in result patterns.
764 auto treeArgName = tree.getArgName(i);
766 // `$_` is a special symbol meaning ignore the current argument.
767 if (!treeArgName.empty() && treeArgName != "_") {
768 DagLeaf leaf = tree.getArgAsLeaf(i);
770 // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c),
771 if (leaf.isUnspecified()) {
772 // This is case of $c, a Value without any constraints.
773 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
774 } else {
775 auto constraint = leaf.getAsConstraint();
776 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
777 leaf.isConstantAttr() ||
778 constraint.getKind() == Constraint::Kind::CK_Attr;
780 if (isAttr) {
781 // This is case of $a, a binding to a certain attribute.
782 verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
783 continue;
786 // This is case of $b, a binding to a certain type.
787 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
792 return;
795 if (tree.isOperation()) {
796 auto &op = getDialectOp(tree);
797 auto numOpArgs = op.getNumArgs();
798 int numEither = 0;
800 // We need to exclude the trailing directives and `either` directive groups
801 // two operands of the operation.
802 int numDirectives = 0;
803 for (int i = numTreeArgs - 1; i >= 0; --i) {
804 if (auto dagArg = tree.getArgAsNestedDag(i)) {
805 if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
806 ++numDirectives;
807 else if (dagArg.isEither())
808 ++numEither;
812 if (numOpArgs != numTreeArgs - numDirectives + numEither) {
813 auto err =
814 formatv("op '{0}' argument number mismatch: "
815 "{1} in pattern vs. {2} in definition",
816 op.getOperationName(), numTreeArgs + numEither, numOpArgs);
817 PrintFatalError(&def, err);
820 // The name attached to the DAG node's operator is for representing the
821 // results generated from this op. It should be remembered as bound results.
822 if (!treeName.empty()) {
823 LLVM_DEBUG(llvm::dbgs()
824 << "found symbol bound to op result: " << treeName << '\n');
825 verifyBind(infoMap.bindOpResult(treeName, op), treeName);
828 // The operand in `either` DAG should be bound to the operation in the
829 // parent DagNode.
830 auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
831 int opArgIdx) {
832 for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
833 if (DagNode subTree = tree.getArgAsNestedDag(i)) {
834 collectBoundSymbols(subTree, infoMap, isSrcPattern);
835 } else {
836 auto argName = tree.getArgName(i);
837 if (!argName.empty() && argName != "_") {
838 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
839 argName);
845 // The operand in `variadic` DAG should be bound to the operation in the
846 // parent DagNode. The range index must be included as well to distinguish
847 // (potentially) repeating argName within the `variadic` DAG.
848 auto collectSymbolInVariadic = [&](DagNode parent, DagNode tree,
849 int opArgIdx) {
850 auto treeName = tree.getSymbol();
851 if (!treeName.empty()) {
852 // If treeName is specified, bind to the full variadic operand_range.
853 verifyBind(infoMap.bindOpArgument(parent, treeName, op, opArgIdx,
854 std::nullopt),
855 treeName);
858 for (int i = 0; i < tree.getNumArgs(); ++i) {
859 if (DagNode subTree = tree.getArgAsNestedDag(i)) {
860 collectBoundSymbols(subTree, infoMap, isSrcPattern);
861 } else {
862 auto argName = tree.getArgName(i);
863 if (!argName.empty() && argName != "_") {
864 verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx,
865 /*variadicSubIndex=*/i),
866 argName);
872 for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
873 if (auto treeArg = tree.getArgAsNestedDag(i)) {
874 if (treeArg.isEither()) {
875 collectSymbolInEither(tree, treeArg, opArgIdx);
876 // `either` DAG is *flattened*. For example,
878 // (FooOp (either arg0, arg1), arg2)
880 // can be viewed as:
882 // (FooOp arg0, arg1, arg2)
883 ++opArgIdx;
884 } else if (treeArg.isVariadic()) {
885 collectSymbolInVariadic(tree, treeArg, opArgIdx);
886 } else {
887 // This DAG node argument is a DAG node itself. Go inside recursively.
888 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
890 continue;
893 if (isSrcPattern) {
894 // We can only bind symbols to op arguments in source pattern. Those
895 // symbols are referenced in result patterns.
896 auto treeArgName = tree.getArgName(i);
897 // `$_` is a special symbol meaning ignore the current argument.
898 if (!treeArgName.empty() && treeArgName != "_") {
899 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
900 << treeArgName << '\n');
901 verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
902 treeArgName);
906 return;
909 if (!treeName.empty()) {
910 PrintFatalError(
911 &def, formatv("binding symbol '{0}' to non-operation/native code call "
912 "unsupported right now",
913 treeName));