[mlir] Fix resource printing in the presence of multiple dialects
[llvm-project.git] / mlir / lib / IR / AsmPrinter.cpp
blobf626d39cb6cf27d8eae82c0d4dea41ec7f0cab1b
1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/DialectResourceBlobManager.h"
24 #include "mlir/IR/IntegerSet.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/IR/OpImplementation.h"
27 #include "mlir/IR/Operation.h"
28 #include "mlir/IR/Verifier.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/MapVector.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/ScopedHashTable.h"
35 #include "llvm/ADT/SetVector.h"
36 #include "llvm/ADT/SmallString.h"
37 #include "llvm/ADT/StringExtras.h"
38 #include "llvm/ADT/StringSet.h"
39 #include "llvm/ADT/TypeSwitch.h"
40 #include "llvm/Support/CommandLine.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/Endian.h"
43 #include "llvm/Support/Regex.h"
44 #include "llvm/Support/SaveAndRestore.h"
45 #include "llvm/Support/Threading.h"
47 #include <optional>
48 #include <tuple>
50 using namespace mlir;
51 using namespace mlir::detail;
53 #define DEBUG_TYPE "mlir-asm-printer"
55 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
57 void OperationName::dump() const { print(llvm::errs()); }
59 //===--------------------------------------------------------------------===//
60 // AsmParser
61 //===--------------------------------------------------------------------===//
63 AsmParser::~AsmParser() = default;
64 DialectAsmParser::~DialectAsmParser() = default;
65 OpAsmParser::~OpAsmParser() = default;
67 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
69 //===----------------------------------------------------------------------===//
70 // DialectAsmPrinter
71 //===----------------------------------------------------------------------===//
73 DialectAsmPrinter::~DialectAsmPrinter() = default;
75 //===----------------------------------------------------------------------===//
76 // OpAsmPrinter
77 //===----------------------------------------------------------------------===//
79 OpAsmPrinter::~OpAsmPrinter() = default;
81 void OpAsmPrinter::printFunctionalType(Operation *op) {
82 auto &os = getStream();
83 os << '(';
84 llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
85 // Print the types of null values as <<NULL TYPE>>.
86 *this << (operand ? operand.getType() : Type());
87 });
88 os << ") -> ";
90 // Print the result list. We don't parenthesize single result types unless
91 // it is a function (avoiding a grammar ambiguity).
92 bool wrapped = op->getNumResults() != 1;
93 if (!wrapped && op->getResult(0).getType() &&
94 llvm::isa<FunctionType>(op->getResult(0).getType()))
95 wrapped = true;
97 if (wrapped)
98 os << '(';
100 llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
101 // Print the types of null values as <<NULL TYPE>>.
102 *this << (result ? result.getType() : Type());
105 if (wrapped)
106 os << ')';
109 //===----------------------------------------------------------------------===//
110 // Operation OpAsm interface.
111 //===----------------------------------------------------------------------===//
113 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
114 #include "mlir/IR/OpAsmInterface.cpp.inc"
116 LogicalResult
117 OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const {
118 return entry.emitError() << "unknown 'resource' key '" << entry.getKey()
119 << "' for dialect '" << getDialect()->getNamespace()
120 << "'";
123 //===----------------------------------------------------------------------===//
124 // OpPrintingFlags
125 //===----------------------------------------------------------------------===//
127 namespace {
128 /// This struct contains command line options that can be used to initialize
129 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
130 /// for global command line options.
131 struct AsmPrinterOptions {
132 llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
133 "mlir-print-elementsattrs-with-hex-if-larger",
134 llvm::cl::desc(
135 "Print DenseElementsAttrs with a hex string that have "
136 "more elements than the given upper limit (use -1 to disable)")};
138 llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
139 "mlir-elide-elementsattrs-if-larger",
140 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
141 "more elements than the given upper limit")};
143 llvm::cl::opt<bool> printDebugInfoOpt{
144 "mlir-print-debuginfo", llvm::cl::init(false),
145 llvm::cl::desc("Print debug info in MLIR output")};
147 llvm::cl::opt<bool> printPrettyDebugInfoOpt{
148 "mlir-pretty-debuginfo", llvm::cl::init(false),
149 llvm::cl::desc("Print pretty debug info in MLIR output")};
151 // Use the generic op output form in the operation printer even if the custom
152 // form is defined.
153 llvm::cl::opt<bool> printGenericOpFormOpt{
154 "mlir-print-op-generic", llvm::cl::init(false),
155 llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
157 llvm::cl::opt<bool> assumeVerifiedOpt{
158 "mlir-print-assume-verified", llvm::cl::init(false),
159 llvm::cl::desc("Skip op verification when using custom printers"),
160 llvm::cl::Hidden};
162 llvm::cl::opt<bool> printLocalScopeOpt{
163 "mlir-print-local-scope", llvm::cl::init(false),
164 llvm::cl::desc("Print with local scope and inline information (eliding "
165 "aliases for attributes, types, and locations")};
167 llvm::cl::opt<bool> printValueUsers{
168 "mlir-print-value-users", llvm::cl::init(false),
169 llvm::cl::desc(
170 "Print users of operation results and block arguments as a comment")};
172 } // namespace
174 static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
176 /// Register a set of useful command-line options that can be used to configure
177 /// various flags within the AsmPrinter.
178 void mlir::registerAsmPrinterCLOptions() {
179 // Make sure that the options struct has been initialized.
180 *clOptions;
183 /// Initialize the printing flags with default supplied by the cl::opts above.
184 OpPrintingFlags::OpPrintingFlags()
185 : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
186 printGenericOpFormFlag(false), skipRegionsFlag(false),
187 assumeVerifiedFlag(false), printLocalScope(false),
188 printValueUsersFlag(false) {
189 // Initialize based upon command line options, if they are available.
190 if (!clOptions.isConstructed())
191 return;
192 if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
193 elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
194 printDebugInfoFlag = clOptions->printDebugInfoOpt;
195 printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
196 printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
197 assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
198 printLocalScope = clOptions->printLocalScopeOpt;
199 printValueUsersFlag = clOptions->printValueUsers;
202 /// Enable the elision of large elements attributes, by printing a '...'
203 /// instead of the element data, when the number of elements is greater than
204 /// `largeElementLimit`. Note: The IR generated with this option is not
205 /// parsable.
206 OpPrintingFlags &
207 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
208 elementsAttrElementLimit = largeElementLimit;
209 return *this;
212 /// Enable printing of debug information. If 'prettyForm' is set to true,
213 /// debug information is printed in a more readable 'pretty' form.
214 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool enable,
215 bool prettyForm) {
216 printDebugInfoFlag = enable;
217 printDebugInfoPrettyFormFlag = prettyForm;
218 return *this;
221 /// Always print operations in the generic form.
222 OpPrintingFlags &OpPrintingFlags::printGenericOpForm(bool enable) {
223 printGenericOpFormFlag = enable;
224 return *this;
227 /// Always skip Regions.
228 OpPrintingFlags &OpPrintingFlags::skipRegions(bool skip) {
229 skipRegionsFlag = skip;
230 return *this;
233 /// Do not verify the operation when using custom operation printers.
234 OpPrintingFlags &OpPrintingFlags::assumeVerified() {
235 assumeVerifiedFlag = true;
236 return *this;
239 /// Use local scope when printing the operation. This allows for using the
240 /// printer in a more localized and thread-safe setting, but may not necessarily
241 /// be identical of what the IR will look like when dumping the full module.
242 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
243 printLocalScope = true;
244 return *this;
247 /// Print users of values as comments.
248 OpPrintingFlags &OpPrintingFlags::printValueUsers() {
249 printValueUsersFlag = true;
250 return *this;
253 /// Return if the given ElementsAttr should be elided.
254 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
255 return elementsAttrElementLimit &&
256 *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
257 !llvm::isa<SplatElementsAttr>(attr);
260 /// Return the size limit for printing large ElementsAttr.
261 std::optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
262 return elementsAttrElementLimit;
265 /// Return if debug information should be printed.
266 bool OpPrintingFlags::shouldPrintDebugInfo() const {
267 return printDebugInfoFlag;
270 /// Return if debug information should be printed in the pretty form.
271 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
272 return printDebugInfoPrettyFormFlag;
275 /// Return if operations should be printed in the generic form.
276 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
277 return printGenericOpFormFlag;
280 /// Return if Region should be skipped.
281 bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; }
283 /// Return if operation verification should be skipped.
284 bool OpPrintingFlags::shouldAssumeVerified() const {
285 return assumeVerifiedFlag;
288 /// Return if the printer should use local scope when dumping the IR.
289 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
291 /// Return if the printer should print users of values.
292 bool OpPrintingFlags::shouldPrintValueUsers() const {
293 return printValueUsersFlag;
296 /// Returns true if an ElementsAttr with the given number of elements should be
297 /// printed with hex.
298 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
299 // Check to see if a command line option was provided for the limit.
300 if (clOptions.isConstructed()) {
301 if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
302 // -1 is used to disable hex printing.
303 if (clOptions->printElementsAttrWithHexIfLarger == -1)
304 return false;
305 return numElements > clOptions->printElementsAttrWithHexIfLarger;
309 // Otherwise, default to printing with hex if the number of elements is >100.
310 return numElements > 100;
313 //===----------------------------------------------------------------------===//
314 // NewLineCounter
315 //===----------------------------------------------------------------------===//
317 namespace {
318 /// This class is a simple formatter that emits a new line when inputted into a
319 /// stream, that enables counting the number of newlines emitted. This class
320 /// should be used whenever emitting newlines in the printer.
321 struct NewLineCounter {
322 unsigned curLine = 1;
325 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
326 ++newLine.curLine;
327 return os << '\n';
329 } // namespace
331 //===----------------------------------------------------------------------===//
332 // AsmPrinter::Impl
333 //===----------------------------------------------------------------------===//
335 namespace mlir {
336 class AsmPrinter::Impl {
337 public:
338 Impl(raw_ostream &os, AsmStateImpl &state);
339 explicit Impl(Impl &other) : Impl(other.os, other.state) {}
341 /// Returns the output stream of the printer.
342 raw_ostream &getStream() { return os; }
344 template <typename Container, typename UnaryFunctor>
345 inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
346 llvm::interleaveComma(c, os, eachFn);
349 /// This enum describes the different kinds of elision for the type of an
350 /// attribute when printing it.
351 enum class AttrTypeElision {
352 /// The type must not be elided,
353 Never,
354 /// The type may be elided when it matches the default used in the parser
355 /// (for example i64 is the default for integer attributes).
356 May,
357 /// The type must be elided.
358 Must
361 /// Print the given attribute or an alias.
362 void printAttribute(Attribute attr,
363 AttrTypeElision typeElision = AttrTypeElision::Never);
364 /// Print the given attribute without considering an alias.
365 void printAttributeImpl(Attribute attr,
366 AttrTypeElision typeElision = AttrTypeElision::Never);
368 /// Print the alias for the given attribute, return failure if no alias could
369 /// be printed.
370 LogicalResult printAlias(Attribute attr);
372 /// Print the given type or an alias.
373 void printType(Type type);
374 /// Print the given type.
375 void printTypeImpl(Type type);
377 /// Print the alias for the given type, return failure if no alias could
378 /// be printed.
379 LogicalResult printAlias(Type type);
381 /// Print the given location to the stream. If `allowAlias` is true, this
382 /// allows for the internal location to use an attribute alias.
383 void printLocation(LocationAttr loc, bool allowAlias = false);
385 /// Print a reference to the given resource that is owned by the given
386 /// dialect.
387 void printResourceHandle(const AsmDialectResourceHandle &resource);
389 void printAffineMap(AffineMap map);
390 void
391 printAffineExpr(AffineExpr expr,
392 function_ref<void(unsigned, bool)> printValueName = nullptr);
393 void printAffineConstraint(AffineExpr expr, bool isEq);
394 void printIntegerSet(IntegerSet set);
396 protected:
397 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
398 ArrayRef<StringRef> elidedAttrs = {},
399 bool withKeyword = false);
400 void printNamedAttribute(NamedAttribute attr);
401 void printTrailingLocation(Location loc, bool allowAlias = true);
402 void printLocationInternal(LocationAttr loc, bool pretty = false,
403 bool isTopLevel = false);
405 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
406 /// used instead of individual elements when the elements attr is large.
407 void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
409 /// Print a dense string elements attribute.
410 void printDenseStringElementsAttr(DenseStringElementsAttr attr);
412 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
413 /// used instead of individual elements when the elements attr is large.
414 void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
415 bool allowHex);
417 /// Print a dense array attribute.
418 void printDenseArrayAttr(DenseArrayAttr attr);
420 void printDialectAttribute(Attribute attr);
421 void printDialectType(Type type);
423 /// Print an escaped string, wrapped with "".
424 void printEscapedString(StringRef str);
426 /// Print a hex string, wrapped with "".
427 void printHexString(StringRef str);
428 void printHexString(ArrayRef<char> data);
430 /// This enum is used to represent the binding strength of the enclosing
431 /// context that an AffineExprStorage is being printed in, so we can
432 /// intelligently produce parens.
433 enum class BindingStrength {
434 Weak, // + and -
435 Strong, // All other binary operators.
437 void printAffineExprInternal(
438 AffineExpr expr, BindingStrength enclosingTightness,
439 function_ref<void(unsigned, bool)> printValueName = nullptr);
441 /// The output stream for the printer.
442 raw_ostream &os;
444 /// An underlying assembly printer state.
445 AsmStateImpl &state;
447 /// A set of flags to control the printer's behavior.
448 OpPrintingFlags printerFlags;
450 /// A tracker for the number of new lines emitted during printing.
451 NewLineCounter newLine;
453 } // namespace mlir
455 //===----------------------------------------------------------------------===//
456 // AliasInitializer
457 //===----------------------------------------------------------------------===//
459 namespace {
460 /// This class represents a specific instance of a symbol Alias.
461 class SymbolAlias {
462 public:
463 SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType,
464 bool isDeferrable)
465 : name(name), suffixIndex(suffixIndex), isType(isType),
466 isDeferrable(isDeferrable) {}
468 /// Print this alias to the given stream.
469 void print(raw_ostream &os) const {
470 os << (isType ? "!" : "#") << name;
471 if (suffixIndex)
472 os << suffixIndex;
475 /// Returns true if this is a type alias.
476 bool isTypeAlias() const { return isType; }
478 /// Returns true if this alias supports deferred resolution when parsing.
479 bool canBeDeferred() const { return isDeferrable; }
481 private:
482 /// The main name of the alias.
483 StringRef name;
484 /// The suffix index of the alias.
485 uint32_t suffixIndex : 30;
486 /// A flag indicating whether this alias is for a type.
487 bool isType : 1;
488 /// A flag indicating whether this alias may be deferred or not.
489 bool isDeferrable : 1;
492 /// This class represents a utility that initializes the set of attribute and
493 /// type aliases, without the need to store the extra information within the
494 /// main AliasState class or pass it around via function arguments.
495 class AliasInitializer {
496 public:
497 AliasInitializer(
498 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
499 llvm::BumpPtrAllocator &aliasAllocator)
500 : interfaces(interfaces), aliasAllocator(aliasAllocator),
501 aliasOS(aliasBuffer) {}
503 void initialize(Operation *op, const OpPrintingFlags &printerFlags,
504 llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);
506 /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
507 /// set to true if the originator of this attribute can resolve the alias
508 /// after parsing has completed (e.g. in the case of operation locations).
509 /// `elideType` indicates if the type of the attribute should be skipped when
510 /// looking for nested aliases. Returns the maximum alias depth of the
511 /// attribute, and the alias index of this attribute.
512 std::pair<size_t, size_t> visit(Attribute attr, bool canBeDeferred = false,
513 bool elideType = false) {
514 return visitImpl(attr, aliases, canBeDeferred, elideType);
517 /// Visit the given type to see if it has an alias. `canBeDeferred` is
518 /// set to true if the originator of this attribute can resolve the alias
519 /// after parsing has completed. Returns the maximum alias depth of the type,
520 /// and the alias index of this type.
521 std::pair<size_t, size_t> visit(Type type, bool canBeDeferred = false) {
522 return visitImpl(type, aliases, canBeDeferred);
525 private:
526 struct InProgressAliasInfo {
527 InProgressAliasInfo()
528 : aliasDepth(0), isType(false), canBeDeferred(false) {}
529 InProgressAliasInfo(StringRef alias, bool isType, bool canBeDeferred)
530 : alias(alias), aliasDepth(1), isType(isType),
531 canBeDeferred(canBeDeferred) {}
533 bool operator<(const InProgressAliasInfo &rhs) const {
534 // Order first by depth, then by attr/type kind, and then by name.
535 if (aliasDepth != rhs.aliasDepth)
536 return aliasDepth < rhs.aliasDepth;
537 if (isType != rhs.isType)
538 return isType;
539 return alias < rhs.alias;
542 /// The alias for the attribute or type, or std::nullopt if the value has no
543 /// alias.
544 std::optional<StringRef> alias;
545 /// The alias depth of this attribute or type, i.e. an indication of the
546 /// relative ordering of when to print this alias.
547 unsigned aliasDepth : 30;
548 /// If this alias represents a type or an attribute.
549 bool isType : 1;
550 /// If this alias can be deferred or not.
551 bool canBeDeferred : 1;
552 /// Indices for child aliases.
553 SmallVector<size_t> childIndices;
556 /// Visit the given attribute or type to see if it has an alias.
557 /// `canBeDeferred` is set to true if the originator of this value can resolve
558 /// the alias after parsing has completed (e.g. in the case of operation
559 /// locations). Returns the maximum alias depth of the value, and its alias
560 /// index.
561 template <typename T, typename... PrintArgs>
562 std::pair<size_t, size_t>
563 visitImpl(T value,
564 llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
565 bool canBeDeferred, PrintArgs &&...printArgs);
567 /// Mark the given alias as non-deferrable.
568 void markAliasNonDeferrable(size_t aliasIndex);
570 /// Try to generate an alias for the provided symbol. If an alias is
571 /// generated, the provided alias mapping and reverse mapping are updated.
572 template <typename T>
573 void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred);
575 /// Given a collection of aliases and symbols, initialize a mapping from a
576 /// symbol to a given alias.
577 static void initializeAliases(
578 llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
579 llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);
581 /// The set of asm interfaces within the context.
582 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
584 /// An allocator used for alias names.
585 llvm::BumpPtrAllocator &aliasAllocator;
587 /// The set of built aliases.
588 llvm::MapVector<const void *, InProgressAliasInfo> aliases;
590 /// Storage and stream used when generating an alias.
591 SmallString<32> aliasBuffer;
592 llvm::raw_svector_ostream aliasOS;
595 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
596 /// and merely collects the attributes and types that *would* be printed in a
597 /// normal print invocation so that we can generate proper aliases. This allows
598 /// for us to generate aliases only for the attributes and types that would be
599 /// in the output, and trims down unnecessary output.
600 class DummyAliasOperationPrinter : private OpAsmPrinter {
601 public:
602 explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
603 AliasInitializer &initializer)
604 : printerFlags(printerFlags), initializer(initializer) {}
606 /// Prints the entire operation with the custom assembly form, if available,
607 /// or the generic assembly form, otherwise.
608 void printCustomOrGenericOp(Operation *op) override {
609 // Visit the operation location.
610 if (printerFlags.shouldPrintDebugInfo())
611 initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
613 // If requested, always print the generic form.
614 if (!printerFlags.shouldPrintGenericOpForm()) {
615 op->getName().printAssembly(op, *this, /*defaultDialect=*/"");
616 return;
619 // Otherwise print with the generic assembly form.
620 printGenericOp(op);
623 private:
624 /// Print the given operation in the generic form.
625 void printGenericOp(Operation *op, bool printOpName = true) override {
626 // Consider nested operations for aliases.
627 if (!printerFlags.shouldSkipRegions()) {
628 for (Region &region : op->getRegions())
629 printRegion(region, /*printEntryBlockArgs=*/true,
630 /*printBlockTerminators=*/true);
633 // Visit all the types used in the operation.
634 for (Type type : op->getOperandTypes())
635 printType(type);
636 for (Type type : op->getResultTypes())
637 printType(type);
639 // Consider the attributes of the operation for aliases.
640 for (const NamedAttribute &attr : op->getAttrs())
641 printAttribute(attr.getValue());
644 /// Print the given block. If 'printBlockArgs' is false, the arguments of the
645 /// block are not printed. If 'printBlockTerminator' is false, the terminator
646 /// operation of the block is not printed.
647 void print(Block *block, bool printBlockArgs = true,
648 bool printBlockTerminator = true) {
649 // Consider the types of the block arguments for aliases if 'printBlockArgs'
650 // is set to true.
651 if (printBlockArgs) {
652 for (BlockArgument arg : block->getArguments()) {
653 printType(arg.getType());
655 // Visit the argument location.
656 if (printerFlags.shouldPrintDebugInfo())
657 // TODO: Allow deferring argument locations.
658 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
662 // Consider the operations within this block, ignoring the terminator if
663 // requested.
664 bool hasTerminator =
665 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
666 auto range = llvm::make_range(
667 block->begin(),
668 std::prev(block->end(),
669 (!hasTerminator || printBlockTerminator) ? 0 : 1));
670 for (Operation &op : range)
671 printCustomOrGenericOp(&op);
674 /// Print the given region.
675 void printRegion(Region &region, bool printEntryBlockArgs,
676 bool printBlockTerminators,
677 bool printEmptyBlock = false) override {
678 if (region.empty())
679 return;
680 if (printerFlags.shouldSkipRegions()) {
681 os << "{...}";
682 return;
685 auto *entryBlock = &region.front();
686 print(entryBlock, printEntryBlockArgs, printBlockTerminators);
687 for (Block &b : llvm::drop_begin(region, 1))
688 print(&b);
691 void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
692 bool omitType) override {
693 printType(arg.getType());
694 // Visit the argument location.
695 if (printerFlags.shouldPrintDebugInfo())
696 // TODO: Allow deferring argument locations.
697 initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
700 /// Consider the given type to be printed for an alias.
701 void printType(Type type) override { initializer.visit(type); }
703 /// Consider the given attribute to be printed for an alias.
704 void printAttribute(Attribute attr) override { initializer.visit(attr); }
705 void printAttributeWithoutType(Attribute attr) override {
706 printAttribute(attr);
708 LogicalResult printAlias(Attribute attr) override {
709 initializer.visit(attr);
710 return success();
712 LogicalResult printAlias(Type type) override {
713 initializer.visit(type);
714 return success();
717 /// Consider the given location to be printed for an alias.
718 void printOptionalLocationSpecifier(Location loc) override {
719 printAttribute(loc);
722 /// Print the given set of attributes with names not included within
723 /// 'elidedAttrs'.
724 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
725 ArrayRef<StringRef> elidedAttrs = {}) override {
726 if (attrs.empty())
727 return;
728 if (elidedAttrs.empty()) {
729 for (const NamedAttribute &attr : attrs)
730 printAttribute(attr.getValue());
731 return;
733 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
734 elidedAttrs.end());
735 for (const NamedAttribute &attr : attrs)
736 if (!elidedAttrsSet.contains(attr.getName().strref()))
737 printAttribute(attr.getValue());
739 void printOptionalAttrDictWithKeyword(
740 ArrayRef<NamedAttribute> attrs,
741 ArrayRef<StringRef> elidedAttrs = {}) override {
742 printOptionalAttrDict(attrs, elidedAttrs);
745 /// Return a null stream as the output stream, this will ignore any data fed
746 /// to it.
747 raw_ostream &getStream() const override { return os; }
749 /// The following are hooks of `OpAsmPrinter` that are not necessary for
750 /// determining potential aliases.
751 void printFloat(const APFloat &) override {}
752 void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
753 void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
754 void printNewline() override {}
755 void increaseIndent() override {}
756 void decreaseIndent() override {}
757 void printOperand(Value) override {}
758 void printOperand(Value, raw_ostream &os) override {
759 // Users expect the output string to have at least the prefixed % to signal
760 // a value name. To maintain this invariant, emit a name even if it is
761 // guaranteed to go unused.
762 os << "%";
764 void printKeywordOrString(StringRef) override {}
765 void printResourceHandle(const AsmDialectResourceHandle &) override {}
766 void printSymbolName(StringRef) override {}
767 void printSuccessor(Block *) override {}
768 void printSuccessorAndUseList(Block *, ValueRange) override {}
769 void shadowRegionArgs(Region &, ValueRange) override {}
771 /// The printer flags to use when determining potential aliases.
772 const OpPrintingFlags &printerFlags;
774 /// The initializer to use when identifying aliases.
775 AliasInitializer &initializer;
777 /// A dummy output stream.
778 mutable llvm::raw_null_ostream os;
781 class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
782 public:
783 explicit DummyAliasDialectAsmPrinter(AliasInitializer &initializer,
784 bool canBeDeferred,
785 SmallVectorImpl<size_t> &childIndices)
786 : initializer(initializer), canBeDeferred(canBeDeferred),
787 childIndices(childIndices) {}
789 /// Print the given attribute/type, visiting any nested aliases that would be
790 /// generated as part of printing. Returns the maximum alias depth found while
791 /// printing the given value.
792 template <typename T, typename... PrintArgs>
793 size_t printAndVisitNestedAliases(T value, PrintArgs &&...printArgs) {
794 printAndVisitNestedAliasesImpl(value, printArgs...);
795 return maxAliasDepth;
798 private:
799 /// Print the given attribute/type, visiting any nested aliases that would be
800 /// generated as part of printing.
801 void printAndVisitNestedAliasesImpl(Attribute attr, bool elideType) {
802 if (!isa<BuiltinDialect>(attr.getDialect())) {
803 attr.getDialect().printAttribute(attr, *this);
805 // Process the builtin attributes.
806 } else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr,
807 IntegerSetAttr, UnitAttr>(attr)) {
808 return;
809 } else if (auto distinctAttr = dyn_cast<DistinctAttr>(attr)) {
810 printAttribute(distinctAttr.getReferencedAttr());
811 } else if (auto dictAttr = dyn_cast<DictionaryAttr>(attr)) {
812 for (const NamedAttribute &nestedAttr : dictAttr.getValue()) {
813 printAttribute(nestedAttr.getName());
814 printAttribute(nestedAttr.getValue());
816 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
817 for (Attribute nestedAttr : arrayAttr.getValue())
818 printAttribute(nestedAttr);
819 } else if (auto typeAttr = dyn_cast<TypeAttr>(attr)) {
820 printType(typeAttr.getValue());
821 } else if (auto locAttr = dyn_cast<OpaqueLoc>(attr)) {
822 printAttribute(locAttr.getFallbackLocation());
823 } else if (auto locAttr = dyn_cast<NameLoc>(attr)) {
824 if (!isa<UnknownLoc>(locAttr.getChildLoc()))
825 printAttribute(locAttr.getChildLoc());
826 } else if (auto locAttr = dyn_cast<CallSiteLoc>(attr)) {
827 printAttribute(locAttr.getCallee());
828 printAttribute(locAttr.getCaller());
829 } else if (auto locAttr = dyn_cast<FusedLoc>(attr)) {
830 if (Attribute metadata = locAttr.getMetadata())
831 printAttribute(metadata);
832 for (Location nestedLoc : locAttr.getLocations())
833 printAttribute(nestedLoc);
836 // Don't print the type if we must elide it, or if it is a None type.
837 if (!elideType) {
838 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
839 Type attrType = typedAttr.getType();
840 if (!llvm::isa<NoneType>(attrType))
841 printType(attrType);
845 void printAndVisitNestedAliasesImpl(Type type) {
846 if (!isa<BuiltinDialect>(type.getDialect()))
847 return type.getDialect().printType(type, *this);
849 // Only visit the layout of memref if it isn't the identity.
850 if (auto memrefTy = llvm::dyn_cast<MemRefType>(type)) {
851 printType(memrefTy.getElementType());
852 MemRefLayoutAttrInterface layout = memrefTy.getLayout();
853 if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity())
854 printAttribute(memrefTy.getLayout());
855 if (memrefTy.getMemorySpace())
856 printAttribute(memrefTy.getMemorySpace());
857 return;
860 // For most builtin types, we can simply walk the sub elements.
861 auto visitFn = [&](auto element) {
862 if (element)
863 (void)printAlias(element);
865 type.walkImmediateSubElements(visitFn, visitFn);
868 /// Consider the given type to be printed for an alias.
869 void printType(Type type) override {
870 recordAliasResult(initializer.visit(type, canBeDeferred));
873 /// Consider the given attribute to be printed for an alias.
874 void printAttribute(Attribute attr) override {
875 recordAliasResult(initializer.visit(attr, canBeDeferred));
877 void printAttributeWithoutType(Attribute attr) override {
878 recordAliasResult(
879 initializer.visit(attr, canBeDeferred, /*elideType=*/true));
881 LogicalResult printAlias(Attribute attr) override {
882 printAttribute(attr);
883 return success();
885 LogicalResult printAlias(Type type) override {
886 printType(type);
887 return success();
890 /// Record the alias result of a child element.
891 void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) {
892 childIndices.push_back(aliasDepthAndIndex.second);
893 if (aliasDepthAndIndex.first > maxAliasDepth)
894 maxAliasDepth = aliasDepthAndIndex.first;
897 /// Return a null stream as the output stream, this will ignore any data fed
898 /// to it.
899 raw_ostream &getStream() const override { return os; }
901 /// The following are hooks of `DialectAsmPrinter` that are not necessary for
902 /// determining potential aliases.
903 void printFloat(const APFloat &) override {}
904 void printKeywordOrString(StringRef) override {}
905 void printSymbolName(StringRef) override {}
906 void printResourceHandle(const AsmDialectResourceHandle &) override {}
908 /// The initializer to use when identifying aliases.
909 AliasInitializer &initializer;
911 /// If the aliases visited by this printer can be deferred.
912 bool canBeDeferred;
914 /// The indices of child aliases.
915 SmallVectorImpl<size_t> &childIndices;
917 /// The maximum alias depth found by the printer.
918 size_t maxAliasDepth = 0;
920 /// A dummy output stream.
921 mutable llvm::raw_null_ostream os;
923 } // namespace
925 /// Sanitize the given name such that it can be used as a valid identifier. If
926 /// the string needs to be modified in any way, the provided buffer is used to
927 /// store the new copy,
928 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
929 StringRef allowedPunctChars = "$._-",
930 bool allowTrailingDigit = true) {
931 assert(!name.empty() && "Shouldn't have an empty name here");
933 auto copyNameToBuffer = [&] {
934 for (char ch : name) {
935 if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
936 buffer.push_back(ch);
937 else if (ch == ' ')
938 buffer.push_back('_');
939 else
940 buffer.append(llvm::utohexstr((unsigned char)ch));
944 // Check to see if this name is valid. If it starts with a digit, then it
945 // could conflict with the autogenerated numeric ID's, so add an underscore
946 // prefix to avoid problems.
947 if (isdigit(name[0])) {
948 buffer.push_back('_');
949 copyNameToBuffer();
950 return buffer;
953 // If the name ends with a trailing digit, add a '_' to avoid potential
954 // conflicts with autogenerated ID's.
955 if (!allowTrailingDigit && isdigit(name.back())) {
956 copyNameToBuffer();
957 buffer.push_back('_');
958 return buffer;
961 // Check to see that the name consists of only valid identifier characters.
962 for (char ch : name) {
963 if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
964 copyNameToBuffer();
965 return buffer;
969 // If there are no invalid characters, return the original name.
970 return name;
973 /// Given a collection of aliases and symbols, initialize a mapping from a
974 /// symbol to a given alias.
975 void AliasInitializer::initializeAliases(
976 llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
977 llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) {
978 SmallVector<std::pair<const void *, InProgressAliasInfo>, 0>
979 unprocessedAliases = visitedSymbols.takeVector();
980 llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) {
981 return lhs.second < rhs.second;
984 llvm::StringMap<unsigned> nameCounts;
985 for (auto &[symbol, aliasInfo] : unprocessedAliases) {
986 if (!aliasInfo.alias)
987 continue;
988 StringRef alias = *aliasInfo.alias;
989 unsigned nameIndex = nameCounts[alias]++;
990 symbolToAlias.insert(
991 {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
992 aliasInfo.canBeDeferred)});
996 void AliasInitializer::initialize(
997 Operation *op, const OpPrintingFlags &printerFlags,
998 llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
999 // Use a dummy printer when walking the IR so that we can collect the
1000 // attributes/types that will actually be used during printing when
1001 // considering aliases.
1002 DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
1003 aliasPrinter.printCustomOrGenericOp(op);
1005 // Initialize the aliases.
1006 initializeAliases(aliases, attrTypeToAlias);
1009 template <typename T, typename... PrintArgs>
1010 std::pair<size_t, size_t> AliasInitializer::visitImpl(
1011 T value, llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
1012 bool canBeDeferred, PrintArgs &&...printArgs) {
1013 auto [it, inserted] =
1014 aliases.insert({value.getAsOpaquePointer(), InProgressAliasInfo()});
1015 size_t aliasIndex = std::distance(aliases.begin(), it);
1016 if (!inserted) {
1017 // Make sure that the alias isn't deferred if we don't permit it.
1018 if (!canBeDeferred)
1019 markAliasNonDeferrable(aliasIndex);
1020 return {static_cast<size_t>(it->second.aliasDepth), aliasIndex};
1023 // Try to generate an alias for this value.
1024 generateAlias(value, it->second, canBeDeferred);
1026 // Print the value, capturing any nested elements that require aliases.
1027 SmallVector<size_t> childAliases;
1028 DummyAliasDialectAsmPrinter printer(*this, canBeDeferred, childAliases);
1029 size_t maxAliasDepth =
1030 printer.printAndVisitNestedAliases(value, printArgs...);
1032 // Make sure to recompute `it` in case the map was reallocated.
1033 it = std::next(aliases.begin(), aliasIndex);
1035 // If we had sub elements, update to account for the depth.
1036 it->second.childIndices = std::move(childAliases);
1037 if (maxAliasDepth)
1038 it->second.aliasDepth = maxAliasDepth + 1;
1040 // Propagate the alias depth of the value.
1041 return {(size_t)it->second.aliasDepth, aliasIndex};
1044 void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
1045 auto it = std::next(aliases.begin(), aliasIndex);
1046 it->second.canBeDeferred = false;
1048 // Propagate the non-deferrable flag to any child aliases.
1049 for (size_t childIndex : it->second.childIndices)
1050 markAliasNonDeferrable(childIndex);
1053 template <typename T>
1054 void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
1055 bool canBeDeferred) {
1056 SmallString<32> nameBuffer;
1057 for (const auto &interface : interfaces) {
1058 OpAsmDialectInterface::AliasResult result =
1059 interface.getAlias(symbol, aliasOS);
1060 if (result == OpAsmDialectInterface::AliasResult::NoAlias)
1061 continue;
1062 nameBuffer = std::move(aliasBuffer);
1063 assert(!nameBuffer.empty() && "expected valid alias name");
1064 if (result == OpAsmDialectInterface::AliasResult::FinalAlias)
1065 break;
1068 if (nameBuffer.empty())
1069 return;
1071 SmallString<16> tempBuffer;
1072 StringRef name =
1073 sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
1074 /*allowTrailingDigit=*/false);
1075 name = name.copy(aliasAllocator);
1076 alias = InProgressAliasInfo(name, /*isType=*/std::is_base_of_v<Type, T>,
1077 canBeDeferred);
1080 //===----------------------------------------------------------------------===//
1081 // AliasState
1082 //===----------------------------------------------------------------------===//
1084 namespace {
1085 /// This class manages the state for type and attribute aliases.
1086 class AliasState {
1087 public:
1088 // Initialize the internal aliases.
1089 void
1090 initialize(Operation *op, const OpPrintingFlags &printerFlags,
1091 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
1093 /// Get an alias for the given attribute if it has one and print it in `os`.
1094 /// Returns success if an alias was printed, failure otherwise.
1095 LogicalResult getAlias(Attribute attr, raw_ostream &os) const;
1097 /// Get an alias for the given type if it has one and print it in `os`.
1098 /// Returns success if an alias was printed, failure otherwise.
1099 LogicalResult getAlias(Type ty, raw_ostream &os) const;
1101 /// Print all of the referenced aliases that can not be resolved in a deferred
1102 /// manner.
1103 void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
1104 printAliases(p, newLine, /*isDeferred=*/false);
1107 /// Print all of the referenced aliases that support deferred resolution.
1108 void printDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
1109 printAliases(p, newLine, /*isDeferred=*/true);
1112 private:
1113 /// Print all of the referenced aliases that support the provided resolution
1114 /// behavior.
1115 void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
1116 bool isDeferred);
1118 /// Mapping between attribute/type and alias.
1119 llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;
1121 /// An allocator used for alias names.
1122 llvm::BumpPtrAllocator aliasAllocator;
1124 } // namespace
1126 void AliasState::initialize(
1127 Operation *op, const OpPrintingFlags &printerFlags,
1128 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
1129 AliasInitializer initializer(interfaces, aliasAllocator);
1130 initializer.initialize(op, printerFlags, attrTypeToAlias);
1133 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
1134 auto it = attrTypeToAlias.find(attr.getAsOpaquePointer());
1135 if (it == attrTypeToAlias.end())
1136 return failure();
1137 it->second.print(os);
1138 return success();
1141 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
1142 auto it = attrTypeToAlias.find(ty.getAsOpaquePointer());
1143 if (it == attrTypeToAlias.end())
1144 return failure();
1146 it->second.print(os);
1147 return success();
1150 void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
1151 bool isDeferred) {
1152 auto filterFn = [=](const auto &aliasIt) {
1153 return aliasIt.second.canBeDeferred() == isDeferred;
1155 for (auto &[opaqueSymbol, alias] :
1156 llvm::make_filter_range(attrTypeToAlias, filterFn)) {
1157 alias.print(p.getStream());
1158 p.getStream() << " = ";
1160 if (alias.isTypeAlias()) {
1161 // TODO: Support nested aliases in mutable types.
1162 Type type = Type::getFromOpaquePointer(opaqueSymbol);
1163 if (type.hasTrait<TypeTrait::IsMutable>())
1164 p.getStream() << type;
1165 else
1166 p.printTypeImpl(type);
1167 } else {
1168 // TODO: Support nested aliases in mutable attributes.
1169 Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol);
1170 if (attr.hasTrait<AttributeTrait::IsMutable>())
1171 p.getStream() << attr;
1172 else
1173 p.printAttributeImpl(attr);
1176 p.getStream() << newLine;
1180 //===----------------------------------------------------------------------===//
1181 // SSANameState
1182 //===----------------------------------------------------------------------===//
1184 namespace {
1185 /// Info about block printing: a number which is its position in the visitation
1186 /// order, and a name that is used to print reference to it, e.g. ^bb42.
1187 struct BlockInfo {
1188 int ordering;
1189 StringRef name;
1192 /// This class manages the state of SSA value names.
1193 class SSANameState {
1194 public:
1195 /// A sentinel value used for values with names set.
1196 enum : unsigned { NameSentinel = ~0U };
1198 SSANameState(Operation *op, const OpPrintingFlags &printerFlags);
1199 SSANameState() = default;
1201 /// Print the SSA identifier for the given value to 'stream'. If
1202 /// 'printResultNo' is true, it also presents the result number ('#' number)
1203 /// of this value.
1204 void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
1206 /// Print the operation identifier.
1207 void printOperationID(Operation *op, raw_ostream &stream) const;
1209 /// Return the result indices for each of the result groups registered by this
1210 /// operation, or empty if none exist.
1211 ArrayRef<int> getOpResultGroups(Operation *op);
1213 /// Get the info for the given block.
1214 BlockInfo getBlockInfo(Block *block);
1216 /// Renumber the arguments for the specified region to the same names as the
1217 /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
1218 /// details.
1219 void shadowRegionArgs(Region &region, ValueRange namesToUse);
1221 private:
1222 /// Number the SSA values within the given IR unit.
1223 void numberValuesInRegion(Region &region);
1224 void numberValuesInBlock(Block &block);
1225 void numberValuesInOp(Operation &op);
1227 /// Given a result of an operation 'result', find the result group head
1228 /// 'lookupValue' and the result of 'result' within that group in
1229 /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
1230 /// has more than 1 result.
1231 void getResultIDAndNumber(OpResult result, Value &lookupValue,
1232 std::optional<int> &lookupResultNo) const;
1234 /// Set a special value name for the given value.
1235 void setValueName(Value value, StringRef name);
1237 /// Uniques the given value name within the printer. If the given name
1238 /// conflicts, it is automatically renamed.
1239 StringRef uniqueValueName(StringRef name);
1241 /// This is the value ID for each SSA value. If this returns NameSentinel,
1242 /// then the valueID has an entry in valueNames.
1243 DenseMap<Value, unsigned> valueIDs;
1244 DenseMap<Value, StringRef> valueNames;
1246 /// When printing users of values, an operation without a result might
1247 /// be the user. This map holds ids for such operations.
1248 DenseMap<Operation *, unsigned> operationIDs;
1250 /// This is a map of operations that contain multiple named result groups,
1251 /// i.e. there may be multiple names for the results of the operation. The
1252 /// value of this map are the result numbers that start a result group.
1253 DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
1255 /// This maps blocks to there visitation number in the current region as well
1256 /// as the string representing their name.
1257 DenseMap<Block *, BlockInfo> blockNames;
1259 /// This keeps track of all of the non-numeric names that are in flight,
1260 /// allowing us to check for duplicates.
1261 /// Note: the value of the map is unused.
1262 llvm::ScopedHashTable<StringRef, char> usedNames;
1263 llvm::BumpPtrAllocator usedNameAllocator;
1265 /// This is the next value ID to assign in numbering.
1266 unsigned nextValueID = 0;
1267 /// This is the next ID to assign to a region entry block argument.
1268 unsigned nextArgumentID = 0;
1269 /// This is the next ID to assign when a name conflict is detected.
1270 unsigned nextConflictID = 0;
1272 /// These are the printing flags. They control, eg., whether to print in
1273 /// generic form.
1274 OpPrintingFlags printerFlags;
1276 } // namespace
1278 SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags)
1279 : printerFlags(printerFlags) {
1280 llvm::SaveAndRestore valueIDSaver(nextValueID);
1281 llvm::SaveAndRestore argumentIDSaver(nextArgumentID);
1282 llvm::SaveAndRestore conflictIDSaver(nextConflictID);
1284 // The naming context includes `nextValueID`, `nextArgumentID`,
1285 // `nextConflictID` and `usedNames` scoped HashTable. This information is
1286 // carried from the parent region.
1287 using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
1288 using NamingContext =
1289 std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;
1291 // Allocator for UsedNamesScopeTy
1292 llvm::BumpPtrAllocator allocator;
1294 // Add a scope for the top level operation.
1295 auto *topLevelNamesScope =
1296 new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);
1298 SmallVector<NamingContext, 8> nameContext;
1299 for (Region &region : op->getRegions())
1300 nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
1301 nextConflictID, topLevelNamesScope));
1303 numberValuesInOp(*op);
1305 while (!nameContext.empty()) {
1306 Region *region;
1307 UsedNamesScopeTy *parentScope;
1308 std::tie(region, nextValueID, nextArgumentID, nextConflictID, parentScope) =
1309 nameContext.pop_back_val();
1311 // When we switch from one subtree to another, pop the scopes(needless)
1312 // until the parent scope.
1313 while (usedNames.getCurScope() != parentScope) {
1314 usedNames.getCurScope()->~UsedNamesScopeTy();
1315 assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
1316 "top level parentScope must be a nullptr");
1319 // Add a scope for the current region.
1320 auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
1321 UsedNamesScopeTy(usedNames);
1323 numberValuesInRegion(*region);
1325 for (Operation &op : region->getOps())
1326 for (Region &region : op.getRegions())
1327 nameContext.push_back(std::make_tuple(&region, nextValueID,
1328 nextArgumentID, nextConflictID,
1329 curNamesScope));
1332 // Manually remove all the scopes.
1333 while (usedNames.getCurScope() != nullptr)
1334 usedNames.getCurScope()->~UsedNamesScopeTy();
1337 void SSANameState::printValueID(Value value, bool printResultNo,
1338 raw_ostream &stream) const {
1339 if (!value) {
1340 stream << "<<NULL VALUE>>";
1341 return;
1344 std::optional<int> resultNo;
1345 auto lookupValue = value;
1347 // If this is an operation result, collect the head lookup value of the result
1348 // group and the result number of 'result' within that group.
1349 if (OpResult result = dyn_cast<OpResult>(value))
1350 getResultIDAndNumber(result, lookupValue, resultNo);
1352 auto it = valueIDs.find(lookupValue);
1353 if (it == valueIDs.end()) {
1354 stream << "<<UNKNOWN SSA VALUE>>";
1355 return;
1358 stream << '%';
1359 if (it->second != NameSentinel) {
1360 stream << it->second;
1361 } else {
1362 auto nameIt = valueNames.find(lookupValue);
1363 assert(nameIt != valueNames.end() && "Didn't have a name entry?");
1364 stream << nameIt->second;
1367 if (resultNo && printResultNo)
1368 stream << '#' << *resultNo;
1371 void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
1372 auto it = operationIDs.find(op);
1373 if (it == operationIDs.end()) {
1374 stream << "<<UNKNOWN OPERATION>>";
1375 } else {
1376 stream << '%' << it->second;
1380 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
1381 auto it = opResultGroups.find(op);
1382 return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
1385 BlockInfo SSANameState::getBlockInfo(Block *block) {
1386 auto it = blockNames.find(block);
1387 BlockInfo invalidBlock{-1, "INVALIDBLOCK"};
1388 return it != blockNames.end() ? it->second : invalidBlock;
1391 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
1392 assert(!region.empty() && "cannot shadow arguments of an empty region");
1393 assert(region.getNumArguments() == namesToUse.size() &&
1394 "incorrect number of names passed in");
1395 assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
1396 "only KnownIsolatedFromAbove ops can shadow names");
1398 SmallVector<char, 16> nameStr;
1399 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
1400 auto nameToUse = namesToUse[i];
1401 if (nameToUse == nullptr)
1402 continue;
1403 auto nameToReplace = region.getArgument(i);
1405 nameStr.clear();
1406 llvm::raw_svector_ostream nameStream(nameStr);
1407 printValueID(nameToUse, /*printResultNo=*/true, nameStream);
1409 // Entry block arguments should already have a pretty "arg" name.
1410 assert(valueIDs[nameToReplace] == NameSentinel);
1412 // Use the name without the leading %.
1413 auto name = StringRef(nameStream.str()).drop_front();
1415 // Overwrite the name.
1416 valueNames[nameToReplace] = name.copy(usedNameAllocator);
1420 void SSANameState::numberValuesInRegion(Region &region) {
1421 auto setBlockArgNameFn = [&](Value arg, StringRef name) {
1422 assert(!valueIDs.count(arg) && "arg numbered multiple times");
1423 assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
1424 "arg not defined in current region");
1425 setValueName(arg, name);
1428 if (!printerFlags.shouldPrintGenericOpForm()) {
1429 if (Operation *op = region.getParentOp()) {
1430 if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
1431 asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1435 // Number the values within this region in a breadth-first order.
1436 unsigned nextBlockID = 0;
1437 for (auto &block : region) {
1438 // Each block gets a unique ID, and all of the operations within it get
1439 // numbered as well.
1440 auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
1441 if (blockInfoIt.second) {
1442 // This block hasn't been named through `getAsmBlockArgumentNames`, use
1443 // default `^bbNNN` format.
1444 std::string name;
1445 llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
1446 blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
1448 blockInfoIt.first->second.ordering = nextBlockID++;
1450 numberValuesInBlock(block);
1454 void SSANameState::numberValuesInBlock(Block &block) {
1455 // Number the block arguments. We give entry block arguments a special name
1456 // 'arg'.
1457 bool isEntryBlock = block.isEntryBlock();
1458 SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
1459 llvm::raw_svector_ostream specialName(specialNameBuffer);
1460 for (auto arg : block.getArguments()) {
1461 if (valueIDs.count(arg))
1462 continue;
1463 if (isEntryBlock) {
1464 specialNameBuffer.resize(strlen("arg"));
1465 specialName << nextArgumentID++;
1467 setValueName(arg, specialName.str());
1470 // Number the operations in this block.
1471 for (auto &op : block)
1472 numberValuesInOp(op);
1475 void SSANameState::numberValuesInOp(Operation &op) {
1476 // Function used to set the special result names for the operation.
1477 SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1478 auto setResultNameFn = [&](Value result, StringRef name) {
1479 assert(!valueIDs.count(result) && "result numbered multiple times");
1480 assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1481 setValueName(result, name);
1483 // Record the result number for groups not anchored at 0.
1484 if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
1485 resultGroups.push_back(resultNo);
1487 // Operations can customize the printing of block names in OpAsmOpInterface.
1488 auto setBlockNameFn = [&](Block *block, StringRef name) {
1489 assert(block->getParentOp() == &op &&
1490 "getAsmBlockArgumentNames callback invoked on a block not directly "
1491 "nested under the current operation");
1492 assert(!blockNames.count(block) && "block numbered multiple times");
1493 SmallString<16> tmpBuffer{"^"};
1494 name = sanitizeIdentifier(name, tmpBuffer);
1495 if (name.data() != tmpBuffer.data()) {
1496 tmpBuffer.append(name);
1497 name = tmpBuffer.str();
1499 name = name.copy(usedNameAllocator);
1500 blockNames[block] = {-1, name};
1503 if (!printerFlags.shouldPrintGenericOpForm()) {
1504 if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
1505 asmInterface.getAsmBlockNames(setBlockNameFn);
1506 asmInterface.getAsmResultNames(setResultNameFn);
1510 unsigned numResults = op.getNumResults();
1511 if (numResults == 0) {
1512 // If value users should be printed, operations with no result need an id.
1513 if (printerFlags.shouldPrintValueUsers()) {
1514 if (operationIDs.try_emplace(&op, nextValueID).second)
1515 ++nextValueID;
1517 return;
1519 Value resultBegin = op.getResult(0);
1521 // If the first result wasn't numbered, give it a default number.
1522 if (valueIDs.try_emplace(resultBegin, nextValueID).second)
1523 ++nextValueID;
1525 // If this operation has multiple result groups, mark it.
1526 if (resultGroups.size() != 1) {
1527 llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
1528 opResultGroups.try_emplace(&op, std::move(resultGroups));
1532 void SSANameState::getResultIDAndNumber(
1533 OpResult result, Value &lookupValue,
1534 std::optional<int> &lookupResultNo) const {
1535 Operation *owner = result.getOwner();
1536 if (owner->getNumResults() == 1)
1537 return;
1538 int resultNo = result.getResultNumber();
1540 // If this operation has multiple result groups, we will need to find the
1541 // one corresponding to this result.
1542 auto resultGroupIt = opResultGroups.find(owner);
1543 if (resultGroupIt == opResultGroups.end()) {
1544 // If not, just use the first result.
1545 lookupResultNo = resultNo;
1546 lookupValue = owner->getResult(0);
1547 return;
1550 // Find the correct index using a binary search, as the groups are ordered.
1551 ArrayRef<int> resultGroups = resultGroupIt->second;
1552 const auto *it = llvm::upper_bound(resultGroups, resultNo);
1553 int groupResultNo = 0, groupSize = 0;
1555 // If there are no smaller elements, the last result group is the lookup.
1556 if (it == resultGroups.end()) {
1557 groupResultNo = resultGroups.back();
1558 groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
1559 } else {
1560 // Otherwise, the previous element is the lookup.
1561 groupResultNo = *std::prev(it);
1562 groupSize = *it - groupResultNo;
1565 // We only record the result number for a group of size greater than 1.
1566 if (groupSize != 1)
1567 lookupResultNo = resultNo - groupResultNo;
1568 lookupValue = owner->getResult(groupResultNo);
1571 void SSANameState::setValueName(Value value, StringRef name) {
1572 // If the name is empty, the value uses the default numbering.
1573 if (name.empty()) {
1574 valueIDs[value] = nextValueID++;
1575 return;
1578 valueIDs[value] = NameSentinel;
1579 valueNames[value] = uniqueValueName(name);
1582 StringRef SSANameState::uniqueValueName(StringRef name) {
1583 SmallString<16> tmpBuffer;
1584 name = sanitizeIdentifier(name, tmpBuffer);
1586 // Check to see if this name is already unique.
1587 if (!usedNames.count(name)) {
1588 name = name.copy(usedNameAllocator);
1589 } else {
1590 // Otherwise, we had a conflict - probe until we find a unique name. This
1591 // is guaranteed to terminate (and usually in a single iteration) because it
1592 // generates new names by incrementing nextConflictID.
1593 SmallString<64> probeName(name);
1594 probeName.push_back('_');
1595 while (true) {
1596 probeName += llvm::utostr(nextConflictID++);
1597 if (!usedNames.count(probeName)) {
1598 name = probeName.str().copy(usedNameAllocator);
1599 break;
1601 probeName.resize(name.size() + 1);
1605 usedNames.insert(name, char());
1606 return name;
1609 //===----------------------------------------------------------------------===//
1610 // DistinctState
1611 //===----------------------------------------------------------------------===//
1613 namespace {
1614 /// This class manages the state for distinct attributes.
1615 class DistinctState {
1616 public:
1617 /// Returns a unique identifier for the given distinct attribute.
1618 uint64_t getId(DistinctAttr distinctAttr);
1620 private:
1621 uint64_t distinctCounter = 0;
1622 DenseMap<DistinctAttr, uint64_t> distinctAttrMap;
1624 } // namespace
1626 uint64_t DistinctState::getId(DistinctAttr distinctAttr) {
1627 auto [it, inserted] =
1628 distinctAttrMap.try_emplace(distinctAttr, distinctCounter);
1629 if (inserted)
1630 distinctCounter++;
1631 return it->getSecond();
1634 //===----------------------------------------------------------------------===//
1635 // Resources
1636 //===----------------------------------------------------------------------===//
1638 AsmParsedResourceEntry::~AsmParsedResourceEntry() = default;
1639 AsmResourceBuilder::~AsmResourceBuilder() = default;
1640 AsmResourceParser::~AsmResourceParser() = default;
1641 AsmResourcePrinter::~AsmResourcePrinter() = default;
1643 StringRef mlir::toString(AsmResourceEntryKind kind) {
1644 switch (kind) {
1645 case AsmResourceEntryKind::Blob:
1646 return "blob";
1647 case AsmResourceEntryKind::Bool:
1648 return "bool";
1649 case AsmResourceEntryKind::String:
1650 return "string";
1652 llvm_unreachable("unknown AsmResourceEntryKind");
1655 AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) {
1656 std::unique_ptr<ResourceCollection> &collection = keyToResources[key.str()];
1657 if (!collection)
1658 collection = std::make_unique<ResourceCollection>(key);
1659 return *collection;
1662 std::vector<std::unique_ptr<AsmResourcePrinter>>
1663 FallbackAsmResourceMap::getPrinters() {
1664 std::vector<std::unique_ptr<AsmResourcePrinter>> printers;
1665 for (auto &it : keyToResources) {
1666 ResourceCollection *collection = it.second.get();
1667 auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) {
1668 return collection->buildResources(op, builder);
1670 printers.emplace_back(
1671 AsmResourcePrinter::fromCallable(collection->getName(), buildValues));
1673 return printers;
1676 LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource(
1677 AsmParsedResourceEntry &entry) {
1678 switch (entry.getKind()) {
1679 case AsmResourceEntryKind::Blob: {
1680 FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
1681 if (failed(blob))
1682 return failure();
1683 resources.emplace_back(entry.getKey(), std::move(*blob));
1684 return success();
1686 case AsmResourceEntryKind::Bool: {
1687 FailureOr<bool> value = entry.parseAsBool();
1688 if (failed(value))
1689 return failure();
1690 resources.emplace_back(entry.getKey(), *value);
1691 break;
1693 case AsmResourceEntryKind::String: {
1694 FailureOr<std::string> str = entry.parseAsString();
1695 if (failed(str))
1696 return failure();
1697 resources.emplace_back(entry.getKey(), std::move(*str));
1698 break;
1701 return success();
1704 void FallbackAsmResourceMap::ResourceCollection::buildResources(
1705 Operation *op, AsmResourceBuilder &builder) const {
1706 for (const auto &entry : resources) {
1707 if (const auto *value = std::get_if<AsmResourceBlob>(&entry.value))
1708 builder.buildBlob(entry.key, *value);
1709 else if (const auto *value = std::get_if<bool>(&entry.value))
1710 builder.buildBool(entry.key, *value);
1711 else if (const auto *value = std::get_if<std::string>(&entry.value))
1712 builder.buildString(entry.key, *value);
1713 else
1714 llvm_unreachable("unknown AsmResourceEntryKind");
1718 //===----------------------------------------------------------------------===//
1719 // AsmState
1720 //===----------------------------------------------------------------------===//
1722 namespace mlir {
1723 namespace detail {
1724 class AsmStateImpl {
1725 public:
1726 explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
1727 AsmState::LocationMap *locationMap)
1728 : interfaces(op->getContext()), nameState(op, printerFlags),
1729 printerFlags(printerFlags), locationMap(locationMap) {}
1730 explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
1731 AsmState::LocationMap *locationMap)
1732 : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {}
1734 /// Initialize the alias state to enable the printing of aliases.
1735 void initializeAliases(Operation *op) {
1736 aliasState.initialize(op, printerFlags, interfaces);
1739 /// Get the state used for aliases.
1740 AliasState &getAliasState() { return aliasState; }
1742 /// Get the state used for SSA names.
1743 SSANameState &getSSANameState() { return nameState; }
1745 /// Get the state used for distinct attribute identifiers.
1746 DistinctState &getDistinctState() { return distinctState; }
1748 /// Return the dialects within the context that implement
1749 /// OpAsmDialectInterface.
1750 DialectInterfaceCollection<OpAsmDialectInterface> &getDialectInterfaces() {
1751 return interfaces;
1754 /// Return the non-dialect resource printers.
1755 auto getResourcePrinters() {
1756 return llvm::make_pointee_range(externalResourcePrinters);
1759 /// Get the printer flags.
1760 const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }
1762 /// Register the location, line and column, within the buffer that the given
1763 /// operation was printed at.
1764 void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
1765 if (locationMap)
1766 (*locationMap)[op] = std::make_pair(line, col);
1769 /// Return the referenced dialect resources within the printer.
1770 DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
1771 getDialectResources() {
1772 return dialectResources;
1775 private:
1776 /// Collection of OpAsm interfaces implemented in the context.
1777 DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
1779 /// A collection of non-dialect resource printers.
1780 SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
1782 /// A set of dialect resources that were referenced during printing.
1783 DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;
1785 /// The state used for attribute and type aliases.
1786 AliasState aliasState;
1788 /// The state used for SSA value names.
1789 SSANameState nameState;
1791 /// The state used for distinct attribute identifiers.
1792 DistinctState distinctState;
1794 /// Flags that control op output.
1795 OpPrintingFlags printerFlags;
1797 /// An optional location map to be populated.
1798 AsmState::LocationMap *locationMap;
1800 // Allow direct access to the impl fields.
1801 friend AsmState;
1803 } // namespace detail
1804 } // namespace mlir
1806 /// Verifies the operation and switches to generic op printing if verification
1807 /// fails. We need to do this because custom print functions may fail for
1808 /// invalid ops.
1809 static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
1810 OpPrintingFlags printerFlags) {
1811 if (printerFlags.shouldPrintGenericOpForm() ||
1812 printerFlags.shouldAssumeVerified())
1813 return printerFlags;
1815 LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Verifying operation: "
1816 << op->getName() << "\n");
1818 // Ignore errors emitted by the verifier. We check the thread id to avoid
1819 // consuming other threads' errors.
1820 auto parentThreadId = llvm::get_threadid();
1821 ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) {
1822 if (parentThreadId == llvm::get_threadid()) {
1823 LLVM_DEBUG({
1824 diag.print(llvm::dbgs());
1825 llvm::dbgs() << "\n";
1827 return success();
1829 return failure();
1831 if (failed(verify(op))) {
1832 LLVM_DEBUG(llvm::dbgs()
1833 << DEBUG_TYPE << ": '" << op->getName()
1834 << "' failed to verify and will be printed in generic form\n");
1835 printerFlags.printGenericOpForm();
1838 return printerFlags;
1841 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
1842 LocationMap *locationMap, FallbackAsmResourceMap *map)
1843 : impl(std::make_unique<AsmStateImpl>(
1844 op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {
1845 if (map)
1846 attachFallbackResourcePrinter(*map);
1848 AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
1849 LocationMap *locationMap, FallbackAsmResourceMap *map)
1850 : impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) {
1851 if (map)
1852 attachFallbackResourcePrinter(*map);
1854 AsmState::~AsmState() = default;
1856 const OpPrintingFlags &AsmState::getPrinterFlags() const {
1857 return impl->getPrinterFlags();
1860 void AsmState::attachResourcePrinter(
1861 std::unique_ptr<AsmResourcePrinter> printer) {
1862 impl->externalResourcePrinters.emplace_back(std::move(printer));
1865 DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
1866 AsmState::getDialectResources() const {
1867 return impl->getDialectResources();
1870 //===----------------------------------------------------------------------===//
1871 // AsmPrinter::Impl
1872 //===----------------------------------------------------------------------===//
1874 AsmPrinter::Impl::Impl(raw_ostream &os, AsmStateImpl &state)
1875 : os(os), state(state), printerFlags(state.getPrinterFlags()) {}
1877 void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
1878 // Check to see if we are printing debug information.
1879 if (!printerFlags.shouldPrintDebugInfo())
1880 return;
1882 os << " ";
1883 printLocation(loc, /*allowAlias=*/allowAlias);
1886 void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
1887 bool isTopLevel) {
1888 // If this isn't a top-level location, check for an alias.
1889 if (!isTopLevel && succeeded(state.getAliasState().getAlias(loc, os)))
1890 return;
1892 TypeSwitch<LocationAttr>(loc)
1893 .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1894 printLocationInternal(loc.getFallbackLocation(), pretty);
1896 .Case<UnknownLoc>([&](UnknownLoc loc) {
1897 if (pretty)
1898 os << "[unknown]";
1899 else
1900 os << "unknown";
1902 .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1903 if (pretty)
1904 os << loc.getFilename().getValue();
1905 else
1906 printEscapedString(loc.getFilename());
1907 os << ':' << loc.getLine() << ':' << loc.getColumn();
1909 .Case<NameLoc>([&](NameLoc loc) {
1910 printEscapedString(loc.getName());
1912 // Print the child if it isn't unknown.
1913 auto childLoc = loc.getChildLoc();
1914 if (!llvm::isa<UnknownLoc>(childLoc)) {
1915 os << '(';
1916 printLocationInternal(childLoc, pretty);
1917 os << ')';
1920 .Case<CallSiteLoc>([&](CallSiteLoc loc) {
1921 Location caller = loc.getCaller();
1922 Location callee = loc.getCallee();
1923 if (!pretty)
1924 os << "callsite(";
1925 printLocationInternal(callee, pretty);
1926 if (pretty) {
1927 if (llvm::isa<NameLoc>(callee)) {
1928 if (llvm::isa<FileLineColLoc>(caller)) {
1929 os << " at ";
1930 } else {
1931 os << newLine << " at ";
1933 } else {
1934 os << newLine << " at ";
1936 } else {
1937 os << " at ";
1939 printLocationInternal(caller, pretty);
1940 if (!pretty)
1941 os << ")";
1943 .Case<FusedLoc>([&](FusedLoc loc) {
1944 if (!pretty)
1945 os << "fused";
1946 if (Attribute metadata = loc.getMetadata()) {
1947 os << '<';
1948 printAttribute(metadata);
1949 os << '>';
1951 os << '[';
1952 interleave(
1953 loc.getLocations(),
1954 [&](Location loc) { printLocationInternal(loc, pretty); },
1955 [&]() { os << ", "; });
1956 os << ']';
1960 /// Print a floating point value in a way that the parser will be able to
1961 /// round-trip losslessly.
1962 static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
1963 // We would like to output the FP constant value in exponential notation,
1964 // but we cannot do this if doing so will lose precision. Check here to
1965 // make sure that we only output it in exponential format if we can parse
1966 // the value back and get the same value.
1967 bool isInf = apValue.isInfinity();
1968 bool isNaN = apValue.isNaN();
1969 if (!isInf && !isNaN) {
1970 SmallString<128> strValue;
1971 apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
1972 /*TruncateZero=*/false);
1974 // Check to make sure that the stringized number is not some string like
1975 // "Inf" or NaN, that atof will accept, but the lexer will not. Check
1976 // that the string matches the "[-+]?[0-9]" regex.
1977 assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
1978 ((strValue[0] == '-' || strValue[0] == '+') &&
1979 (strValue[1] >= '0' && strValue[1] <= '9'))) &&
1980 "[-+]?[0-9] regex does not match!");
1982 // Parse back the stringized version and check that the value is equal
1983 // (i.e., there is no precision loss).
1984 if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
1985 os << strValue;
1986 return;
1989 // If it is not, use the default format of APFloat instead of the
1990 // exponential notation.
1991 strValue.clear();
1992 apValue.toString(strValue);
1994 // Make sure that we can parse the default form as a float.
1995 if (strValue.str().contains('.')) {
1996 os << strValue;
1997 return;
2001 // Print special values in hexadecimal format. The sign bit should be included
2002 // in the literal.
2003 SmallVector<char, 16> str;
2004 APInt apInt = apValue.bitcastToAPInt();
2005 apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
2006 /*formatAsCLiteral=*/true);
2007 os << str;
2010 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
2011 if (printerFlags.shouldPrintDebugInfoPrettyForm())
2012 return printLocationInternal(loc, /*pretty=*/true, /*isTopLevel=*/true);
2014 os << "loc(";
2015 if (!allowAlias || failed(printAlias(loc)))
2016 printLocationInternal(loc, /*pretty=*/false, /*isTopLevel=*/true);
2017 os << ')';
2020 void AsmPrinter::Impl::printResourceHandle(
2021 const AsmDialectResourceHandle &resource) {
2022 auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
2023 os << interface->getResourceKey(resource);
2024 state.getDialectResources()[resource.getDialect()].insert(resource);
2027 /// Returns true if the given dialect symbol data is simple enough to print in
2028 /// the pretty form. This is essentially when the symbol takes the form:
2029 /// identifier (`<` body `>`)?
2030 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
2031 // The name must start with an identifier.
2032 if (symName.empty() || !isalpha(symName.front()))
2033 return false;
2035 // Ignore all the characters that are valid in an identifier in the symbol
2036 // name.
2037 symName = symName.drop_while(
2038 [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
2039 if (symName.empty())
2040 return true;
2042 // If we got to an unexpected character, then it must be a <>. Check that the
2043 // rest of the symbol is wrapped within <>.
2044 return symName.front() == '<' && symName.back() == '>';
2047 /// Print the given dialect symbol to the stream.
2048 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
2049 StringRef dialectName, StringRef symString) {
2050 os << symPrefix << dialectName;
2052 // If this symbol name is simple enough, print it directly in pretty form,
2053 // otherwise, we print it as an escaped string.
2054 if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
2055 os << '.' << symString;
2056 return;
2059 os << '<' << symString << '>';
2062 /// Returns true if the given string can be represented as a bare identifier.
2063 static bool isBareIdentifier(StringRef name) {
2064 // By making this unsigned, the value passed in to isalnum will always be
2065 // in the range 0-255. This is important when building with MSVC because
2066 // its implementation will assert. This situation can arise when dealing
2067 // with UTF-8 multibyte characters.
2068 if (name.empty() || (!isalpha(name[0]) && name[0] != '_'))
2069 return false;
2070 return llvm::all_of(name.drop_front(), [](unsigned char c) {
2071 return isalnum(c) || c == '_' || c == '$' || c == '.';
2075 /// Print the given string as a keyword, or a quoted and escaped string if it
2076 /// has any special or non-printable characters in it.
2077 static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
2078 // If it can be represented as a bare identifier, write it directly.
2079 if (isBareIdentifier(keyword)) {
2080 os << keyword;
2081 return;
2084 // Otherwise, output the keyword wrapped in quotes with proper escaping.
2085 os << "\"";
2086 printEscapedString(keyword, os);
2087 os << '"';
2090 /// Print the given string as a symbol reference. A symbol reference is
2091 /// represented as a string prefixed with '@'. The reference is surrounded with
2092 /// ""'s and escaped if it has any special or non-printable characters in it.
2093 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
2094 if (symbolRef.empty()) {
2095 os << "@<<INVALID EMPTY SYMBOL>>";
2096 return;
2098 os << '@';
2099 printKeywordOrString(symbolRef, os);
2102 // Print out a valid ElementsAttr that is succinct and can represent any
2103 // potential shape/type, for use when eliding a large ElementsAttr.
2105 // We choose to use a dense resource ElementsAttr literal with conspicuous
2106 // content to hopefully alert readers to the fact that this has been elided.
2107 static void printElidedElementsAttr(raw_ostream &os) {
2108 os << R"(dense_resource<__elided__>)";
2111 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
2112 return state.getAliasState().getAlias(attr, os);
2115 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
2116 return state.getAliasState().getAlias(type, os);
2119 void AsmPrinter::Impl::printAttribute(Attribute attr,
2120 AttrTypeElision typeElision) {
2121 if (!attr) {
2122 os << "<<NULL ATTRIBUTE>>";
2123 return;
2126 // Try to print an alias for this attribute.
2127 if (succeeded(printAlias(attr)))
2128 return;
2129 return printAttributeImpl(attr, typeElision);
2132 void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
2133 AttrTypeElision typeElision) {
2134 if (!isa<BuiltinDialect>(attr.getDialect())) {
2135 printDialectAttribute(attr);
2136 } else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(attr)) {
2137 printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
2138 opaqueAttr.getAttrData());
2139 } else if (llvm::isa<UnitAttr>(attr)) {
2140 os << "unit";
2141 return;
2142 } else if (auto distinctAttr = llvm::dyn_cast<DistinctAttr>(attr)) {
2143 os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<";
2144 if (!llvm::isa<UnitAttr>(distinctAttr.getReferencedAttr())) {
2145 printAttribute(distinctAttr.getReferencedAttr());
2147 os << '>';
2148 return;
2149 } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
2150 os << '{';
2151 interleaveComma(dictAttr.getValue(),
2152 [&](NamedAttribute attr) { printNamedAttribute(attr); });
2153 os << '}';
2155 } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
2156 Type intType = intAttr.getType();
2157 if (intType.isSignlessInteger(1)) {
2158 os << (intAttr.getValue().getBoolValue() ? "true" : "false");
2160 // Boolean integer attributes always elides the type.
2161 return;
2164 // Only print attributes as unsigned if they are explicitly unsigned or are
2165 // signless 1-bit values. Indexes, signed values, and multi-bit signless
2166 // values print as signed.
2167 bool isUnsigned =
2168 intType.isUnsignedInteger() || intType.isSignlessInteger(1);
2169 intAttr.getValue().print(os, !isUnsigned);
2171 // IntegerAttr elides the type if I64.
2172 if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64))
2173 return;
2175 } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) {
2176 printFloatValue(floatAttr.getValue(), os);
2178 // FloatAttr elides the type if F64.
2179 if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64())
2180 return;
2182 } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attr)) {
2183 printEscapedString(strAttr.getValue());
2185 } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) {
2186 os << '[';
2187 interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
2188 printAttribute(attr, AttrTypeElision::May);
2190 os << ']';
2192 } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) {
2193 os << "affine_map<";
2194 affineMapAttr.getValue().print(os);
2195 os << '>';
2197 // AffineMap always elides the type.
2198 return;
2200 } else if (auto integerSetAttr = llvm::dyn_cast<IntegerSetAttr>(attr)) {
2201 os << "affine_set<";
2202 integerSetAttr.getValue().print(os);
2203 os << '>';
2205 // IntegerSet always elides the type.
2206 return;
2208 } else if (auto typeAttr = llvm::dyn_cast<TypeAttr>(attr)) {
2209 printType(typeAttr.getValue());
2211 } else if (auto refAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) {
2212 printSymbolReference(refAttr.getRootReference().getValue(), os);
2213 for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
2214 os << "::";
2215 printSymbolReference(nestedRef.getValue(), os);
2218 } else if (auto intOrFpEltAttr =
2219 llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
2220 if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
2221 printElidedElementsAttr(os);
2222 } else {
2223 os << "dense<";
2224 printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
2225 os << '>';
2228 } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) {
2229 if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
2230 printElidedElementsAttr(os);
2231 } else {
2232 os << "dense<";
2233 printDenseStringElementsAttr(strEltAttr);
2234 os << '>';
2237 } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) {
2238 if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
2239 printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
2240 printElidedElementsAttr(os);
2241 } else {
2242 os << "sparse<";
2243 DenseIntElementsAttr indices = sparseEltAttr.getIndices();
2244 if (indices.getNumElements() != 0) {
2245 printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
2246 os << ", ";
2247 printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
2249 os << '>';
2251 } else if (auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(attr)) {
2252 stridedLayoutAttr.print(os);
2253 } else if (auto denseArrayAttr = llvm::dyn_cast<DenseArrayAttr>(attr)) {
2254 os << "array<";
2255 printType(denseArrayAttr.getElementType());
2256 if (!denseArrayAttr.empty()) {
2257 os << ": ";
2258 printDenseArrayAttr(denseArrayAttr);
2260 os << ">";
2261 return;
2262 } else if (auto resourceAttr =
2263 llvm::dyn_cast<DenseResourceElementsAttr>(attr)) {
2264 os << "dense_resource<";
2265 printResourceHandle(resourceAttr.getRawHandle());
2266 os << ">";
2267 } else if (auto locAttr = llvm::dyn_cast<LocationAttr>(attr)) {
2268 printLocation(locAttr);
2269 } else {
2270 llvm::report_fatal_error("Unknown builtin attribute");
2272 // Don't print the type if we must elide it, or if it is a None type.
2273 if (typeElision != AttrTypeElision::Must) {
2274 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
2275 Type attrType = typedAttr.getType();
2276 if (!llvm::isa<NoneType>(attrType)) {
2277 os << " : ";
2278 printType(attrType);
2284 /// Print the integer element of a DenseElementsAttr.
2285 static void printDenseIntElement(const APInt &value, raw_ostream &os,
2286 Type type) {
2287 if (type.isInteger(1))
2288 os << (value.getBoolValue() ? "true" : "false");
2289 else
2290 value.print(os, !type.isUnsignedInteger());
2293 static void
2294 printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
2295 function_ref<void(unsigned)> printEltFn) {
2296 // Special case for 0-d and splat tensors.
2297 if (isSplat)
2298 return printEltFn(0);
2300 // Special case for degenerate tensors.
2301 auto numElements = type.getNumElements();
2302 if (numElements == 0)
2303 return;
2305 // We use a mixed-radix counter to iterate through the shape. When we bump a
2306 // non-least-significant digit, we emit a close bracket. When we next emit an
2307 // element we re-open all closed brackets.
2309 // The mixed-radix counter, with radices in 'shape'.
2310 int64_t rank = type.getRank();
2311 SmallVector<unsigned, 4> counter(rank, 0);
2312 // The number of brackets that have been opened and not closed.
2313 unsigned openBrackets = 0;
2315 auto shape = type.getShape();
2316 auto bumpCounter = [&] {
2317 // Bump the least significant digit.
2318 ++counter[rank - 1];
2319 // Iterate backwards bubbling back the increment.
2320 for (unsigned i = rank - 1; i > 0; --i)
2321 if (counter[i] >= shape[i]) {
2322 // Index 'i' is rolled over. Bump (i-1) and close a bracket.
2323 counter[i] = 0;
2324 ++counter[i - 1];
2325 --openBrackets;
2326 os << ']';
2330 for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
2331 if (idx != 0)
2332 os << ", ";
2333 while (openBrackets++ < rank)
2334 os << '[';
2335 openBrackets = rank;
2336 printEltFn(idx);
2337 bumpCounter();
2339 while (openBrackets-- > 0)
2340 os << ']';
2343 void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
2344 bool allowHex) {
2345 if (auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr))
2346 return printDenseStringElementsAttr(stringAttr);
2348 printDenseIntOrFPElementsAttr(llvm::cast<DenseIntOrFPElementsAttr>(attr),
2349 allowHex);
2352 void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
2353 DenseIntOrFPElementsAttr attr, bool allowHex) {
2354 auto type = attr.getType();
2355 auto elementType = type.getElementType();
2357 // Check to see if we should format this attribute as a hex string.
2358 auto numElements = type.getNumElements();
2359 if (!attr.isSplat() && allowHex &&
2360 shouldPrintElementsAttrWithHex(numElements)) {
2361 ArrayRef<char> rawData = attr.getRawData();
2362 if (llvm::support::endian::system_endianness() ==
2363 llvm::support::endianness::big) {
2364 // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
2365 // machines. It is converted here to print in LE format.
2366 SmallVector<char, 64> outDataVec(rawData.size());
2367 MutableArrayRef<char> convRawData(outDataVec);
2368 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
2369 rawData, convRawData, type);
2370 printHexString(convRawData);
2371 } else {
2372 printHexString(rawData);
2375 return;
2378 if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) {
2379 Type complexElementType = complexTy.getElementType();
2380 // Note: The if and else below had a common lambda function which invoked
2381 // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
2382 // and hence was replaced.
2383 if (llvm::isa<IntegerType>(complexElementType)) {
2384 auto valueIt = attr.value_begin<std::complex<APInt>>();
2385 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2386 auto complexValue = *(valueIt + index);
2387 os << "(";
2388 printDenseIntElement(complexValue.real(), os, complexElementType);
2389 os << ",";
2390 printDenseIntElement(complexValue.imag(), os, complexElementType);
2391 os << ")";
2393 } else {
2394 auto valueIt = attr.value_begin<std::complex<APFloat>>();
2395 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2396 auto complexValue = *(valueIt + index);
2397 os << "(";
2398 printFloatValue(complexValue.real(), os);
2399 os << ",";
2400 printFloatValue(complexValue.imag(), os);
2401 os << ")";
2404 } else if (elementType.isIntOrIndex()) {
2405 auto valueIt = attr.value_begin<APInt>();
2406 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2407 printDenseIntElement(*(valueIt + index), os, elementType);
2409 } else {
2410 assert(llvm::isa<FloatType>(elementType) && "unexpected element type");
2411 auto valueIt = attr.value_begin<APFloat>();
2412 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2413 printFloatValue(*(valueIt + index), os);
2418 void AsmPrinter::Impl::printDenseStringElementsAttr(
2419 DenseStringElementsAttr attr) {
2420 ArrayRef<StringRef> data = attr.getRawStringData();
2421 auto printFn = [&](unsigned index) { printEscapedString(data[index]); };
2422 printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
2425 void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
2426 Type type = attr.getElementType();
2427 unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth();
2428 unsigned byteSize = bitwidth / 8;
2429 ArrayRef<char> data = attr.getRawData();
2431 auto printElementAt = [&](unsigned i) {
2432 APInt value(bitwidth, 0);
2433 if (bitwidth) {
2434 llvm::LoadIntFromMemory(
2435 value, reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i),
2436 byteSize);
2438 // Print the data as-is or as a float.
2439 if (type.isIntOrIndex()) {
2440 printDenseIntElement(value, getStream(), type);
2441 } else {
2442 APFloat fltVal(llvm::cast<FloatType>(type).getFloatSemantics(), value);
2443 printFloatValue(fltVal, getStream());
2446 llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
2447 printElementAt);
2450 void AsmPrinter::Impl::printType(Type type) {
2451 if (!type) {
2452 os << "<<NULL TYPE>>";
2453 return;
2456 // Try to print an alias for this type.
2457 if (succeeded(printAlias(type)))
2458 return;
2459 return printTypeImpl(type);
2462 void AsmPrinter::Impl::printTypeImpl(Type type) {
2463 TypeSwitch<Type>(type)
2464 .Case<OpaqueType>([&](OpaqueType opaqueTy) {
2465 printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
2466 opaqueTy.getTypeData());
2468 .Case<IndexType>([&](Type) { os << "index"; })
2469 .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
2470 .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
2471 .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
2472 .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
2473 .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
2474 .Case<BFloat16Type>([&](Type) { os << "bf16"; })
2475 .Case<Float16Type>([&](Type) { os << "f16"; })
2476 .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
2477 .Case<Float32Type>([&](Type) { os << "f32"; })
2478 .Case<Float64Type>([&](Type) { os << "f64"; })
2479 .Case<Float80Type>([&](Type) { os << "f80"; })
2480 .Case<Float128Type>([&](Type) { os << "f128"; })
2481 .Case<IntegerType>([&](IntegerType integerTy) {
2482 if (integerTy.isSigned())
2483 os << 's';
2484 else if (integerTy.isUnsigned())
2485 os << 'u';
2486 os << 'i' << integerTy.getWidth();
2488 .Case<FunctionType>([&](FunctionType funcTy) {
2489 os << '(';
2490 interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
2491 os << ") -> ";
2492 ArrayRef<Type> results = funcTy.getResults();
2493 if (results.size() == 1 && !llvm::isa<FunctionType>(results[0])) {
2494 printType(results[0]);
2495 } else {
2496 os << '(';
2497 interleaveComma(results, [&](Type ty) { printType(ty); });
2498 os << ')';
2501 .Case<VectorType>([&](VectorType vectorTy) {
2502 auto scalableDims = vectorTy.getScalableDims();
2503 os << "vector<";
2504 auto vShape = vectorTy.getShape();
2505 unsigned lastDim = vShape.size();
2506 unsigned dimIdx = 0;
2507 for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
2508 if (!scalableDims.empty() && scalableDims[dimIdx])
2509 os << '[';
2510 os << vShape[dimIdx];
2511 if (!scalableDims.empty() && scalableDims[dimIdx])
2512 os << ']';
2513 os << 'x';
2515 printType(vectorTy.getElementType());
2516 os << '>';
2518 .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
2519 os << "tensor<";
2520 for (int64_t dim : tensorTy.getShape()) {
2521 if (ShapedType::isDynamic(dim))
2522 os << '?';
2523 else
2524 os << dim;
2525 os << 'x';
2527 printType(tensorTy.getElementType());
2528 // Only print the encoding attribute value if set.
2529 if (tensorTy.getEncoding()) {
2530 os << ", ";
2531 printAttribute(tensorTy.getEncoding());
2533 os << '>';
2535 .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
2536 os << "tensor<*x";
2537 printType(tensorTy.getElementType());
2538 os << '>';
2540 .Case<MemRefType>([&](MemRefType memrefTy) {
2541 os << "memref<";
2542 for (int64_t dim : memrefTy.getShape()) {
2543 if (ShapedType::isDynamic(dim))
2544 os << '?';
2545 else
2546 os << dim;
2547 os << 'x';
2549 printType(memrefTy.getElementType());
2550 MemRefLayoutAttrInterface layout = memrefTy.getLayout();
2551 if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
2552 os << ", ";
2553 printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
2555 // Only print the memory space if it is the non-default one.
2556 if (memrefTy.getMemorySpace()) {
2557 os << ", ";
2558 printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
2560 os << '>';
2562 .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
2563 os << "memref<*x";
2564 printType(memrefTy.getElementType());
2565 // Only print the memory space if it is the non-default one.
2566 if (memrefTy.getMemorySpace()) {
2567 os << ", ";
2568 printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
2570 os << '>';
2572 .Case<ComplexType>([&](ComplexType complexTy) {
2573 os << "complex<";
2574 printType(complexTy.getElementType());
2575 os << '>';
2577 .Case<TupleType>([&](TupleType tupleTy) {
2578 os << "tuple<";
2579 interleaveComma(tupleTy.getTypes(),
2580 [&](Type type) { printType(type); });
2581 os << '>';
2583 .Case<NoneType>([&](Type) { os << "none"; })
2584 .Default([&](Type type) { return printDialectType(type); });
2587 void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2588 ArrayRef<StringRef> elidedAttrs,
2589 bool withKeyword) {
2590 // If there are no attributes, then there is nothing to be done.
2591 if (attrs.empty())
2592 return;
2594 // Functor used to print a filtered attribute list.
2595 auto printFilteredAttributesFn = [&](auto filteredAttrs) {
2596 // Print the 'attributes' keyword if necessary.
2597 if (withKeyword)
2598 os << " attributes";
2600 // Otherwise, print them all out in braces.
2601 os << " {";
2602 interleaveComma(filteredAttrs,
2603 [&](NamedAttribute attr) { printNamedAttribute(attr); });
2604 os << '}';
2607 // If no attributes are elided, we can directly print with no filtering.
2608 if (elidedAttrs.empty())
2609 return printFilteredAttributesFn(attrs);
2611 // Otherwise, filter out any attributes that shouldn't be included.
2612 llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2613 elidedAttrs.end());
2614 auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
2615 return !elidedAttrsSet.contains(attr.getName().strref());
2617 if (!filteredAttrs.empty())
2618 printFilteredAttributesFn(filteredAttrs);
2620 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
2621 // Print the name without quotes if possible.
2622 ::printKeywordOrString(attr.getName().strref(), os);
2624 // Pretty printing elides the attribute value for unit attributes.
2625 if (llvm::isa<UnitAttr>(attr.getValue()))
2626 return;
2628 os << " = ";
2629 printAttribute(attr.getValue());
2632 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
2633 auto &dialect = attr.getDialect();
2635 // Ask the dialect to serialize the attribute to a string.
2636 std::string attrName;
2638 llvm::raw_string_ostream attrNameStr(attrName);
2639 Impl subPrinter(attrNameStr, state);
2640 DialectAsmPrinter printer(subPrinter);
2641 dialect.printAttribute(attr, printer);
2643 printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
2646 void AsmPrinter::Impl::printDialectType(Type type) {
2647 auto &dialect = type.getDialect();
2649 // Ask the dialect to serialize the type to a string.
2650 std::string typeName;
2652 llvm::raw_string_ostream typeNameStr(typeName);
2653 Impl subPrinter(typeNameStr, state);
2654 DialectAsmPrinter printer(subPrinter);
2655 dialect.printType(type, printer);
2657 printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
2660 void AsmPrinter::Impl::printEscapedString(StringRef str) {
2661 os << "\"";
2662 llvm::printEscapedString(str, os);
2663 os << "\"";
2666 void AsmPrinter::Impl::printHexString(StringRef str) {
2667 os << "\"0x" << llvm::toHex(str) << "\"";
2669 void AsmPrinter::Impl::printHexString(ArrayRef<char> data) {
2670 printHexString(StringRef(data.data(), data.size()));
2673 //===--------------------------------------------------------------------===//
2674 // AsmPrinter
2675 //===--------------------------------------------------------------------===//
2677 AsmPrinter::~AsmPrinter() = default;
2679 raw_ostream &AsmPrinter::getStream() const {
2680 assert(impl && "expected AsmPrinter::getStream to be overriden");
2681 return impl->getStream();
2684 /// Print the given floating point value in a stablized form.
2685 void AsmPrinter::printFloat(const APFloat &value) {
2686 assert(impl && "expected AsmPrinter::printFloat to be overriden");
2687 printFloatValue(value, impl->getStream());
2690 void AsmPrinter::printType(Type type) {
2691 assert(impl && "expected AsmPrinter::printType to be overriden");
2692 impl->printType(type);
2695 void AsmPrinter::printAttribute(Attribute attr) {
2696 assert(impl && "expected AsmPrinter::printAttribute to be overriden");
2697 impl->printAttribute(attr);
2700 LogicalResult AsmPrinter::printAlias(Attribute attr) {
2701 assert(impl && "expected AsmPrinter::printAlias to be overriden");
2702 return impl->printAlias(attr);
2705 LogicalResult AsmPrinter::printAlias(Type type) {
2706 assert(impl && "expected AsmPrinter::printAlias to be overriden");
2707 return impl->printAlias(type);
2710 void AsmPrinter::printAttributeWithoutType(Attribute attr) {
2711 assert(impl &&
2712 "expected AsmPrinter::printAttributeWithoutType to be overriden");
2713 impl->printAttribute(attr, Impl::AttrTypeElision::Must);
2716 void AsmPrinter::printKeywordOrString(StringRef keyword) {
2717 assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
2718 ::printKeywordOrString(keyword, impl->getStream());
2721 void AsmPrinter::printSymbolName(StringRef symbolRef) {
2722 assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
2723 ::printSymbolReference(symbolRef, impl->getStream());
2726 void AsmPrinter::printResourceHandle(const AsmDialectResourceHandle &resource) {
2727 assert(impl && "expected AsmPrinter::printResourceHandle to be overriden");
2728 impl->printResourceHandle(resource);
2731 //===----------------------------------------------------------------------===//
2732 // Affine expressions and maps
2733 //===----------------------------------------------------------------------===//
2735 void AsmPrinter::Impl::printAffineExpr(
2736 AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
2737 printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
2740 void AsmPrinter::Impl::printAffineExprInternal(
2741 AffineExpr expr, BindingStrength enclosingTightness,
2742 function_ref<void(unsigned, bool)> printValueName) {
2743 const char *binopSpelling = nullptr;
2744 switch (expr.getKind()) {
2745 case AffineExprKind::SymbolId: {
2746 unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
2747 if (printValueName)
2748 printValueName(pos, /*isSymbol=*/true);
2749 else
2750 os << 's' << pos;
2751 return;
2753 case AffineExprKind::DimId: {
2754 unsigned pos = expr.cast<AffineDimExpr>().getPosition();
2755 if (printValueName)
2756 printValueName(pos, /*isSymbol=*/false);
2757 else
2758 os << 'd' << pos;
2759 return;
2761 case AffineExprKind::Constant:
2762 os << expr.cast<AffineConstantExpr>().getValue();
2763 return;
2764 case AffineExprKind::Add:
2765 binopSpelling = " + ";
2766 break;
2767 case AffineExprKind::Mul:
2768 binopSpelling = " * ";
2769 break;
2770 case AffineExprKind::FloorDiv:
2771 binopSpelling = " floordiv ";
2772 break;
2773 case AffineExprKind::CeilDiv:
2774 binopSpelling = " ceildiv ";
2775 break;
2776 case AffineExprKind::Mod:
2777 binopSpelling = " mod ";
2778 break;
2781 auto binOp = expr.cast<AffineBinaryOpExpr>();
2782 AffineExpr lhsExpr = binOp.getLHS();
2783 AffineExpr rhsExpr = binOp.getRHS();
2785 // Handle tightly binding binary operators.
2786 if (binOp.getKind() != AffineExprKind::Add) {
2787 if (enclosingTightness == BindingStrength::Strong)
2788 os << '(';
2790 // Pretty print multiplication with -1.
2791 auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
2792 if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
2793 rhsConst.getValue() == -1) {
2794 os << "-";
2795 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2796 if (enclosingTightness == BindingStrength::Strong)
2797 os << ')';
2798 return;
2801 printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2803 os << binopSpelling;
2804 printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
2806 if (enclosingTightness == BindingStrength::Strong)
2807 os << ')';
2808 return;
2811 // Print out special "pretty" forms for add.
2812 if (enclosingTightness == BindingStrength::Strong)
2813 os << '(';
2815 // Pretty print addition to a product that has a negative operand as a
2816 // subtraction.
2817 if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
2818 if (rhs.getKind() == AffineExprKind::Mul) {
2819 AffineExpr rrhsExpr = rhs.getRHS();
2820 if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
2821 if (rrhs.getValue() == -1) {
2822 printAffineExprInternal(lhsExpr, BindingStrength::Weak,
2823 printValueName);
2824 os << " - ";
2825 if (rhs.getLHS().getKind() == AffineExprKind::Add) {
2826 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
2827 printValueName);
2828 } else {
2829 printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
2830 printValueName);
2833 if (enclosingTightness == BindingStrength::Strong)
2834 os << ')';
2835 return;
2838 if (rrhs.getValue() < -1) {
2839 printAffineExprInternal(lhsExpr, BindingStrength::Weak,
2840 printValueName);
2841 os << " - ";
2842 printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
2843 printValueName);
2844 os << " * " << -rrhs.getValue();
2845 if (enclosingTightness == BindingStrength::Strong)
2846 os << ')';
2847 return;
2853 // Pretty print addition to a negative number as a subtraction.
2854 if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
2855 if (rhsConst.getValue() < 0) {
2856 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
2857 os << " - " << -rhsConst.getValue();
2858 if (enclosingTightness == BindingStrength::Strong)
2859 os << ')';
2860 return;
2864 printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
2866 os << " + ";
2867 printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
2869 if (enclosingTightness == BindingStrength::Strong)
2870 os << ')';
2873 void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
2874 printAffineExprInternal(expr, BindingStrength::Weak);
2875 isEq ? os << " == 0" : os << " >= 0";
2878 void AsmPrinter::Impl::printAffineMap(AffineMap map) {
2879 // Dimension identifiers.
2880 os << '(';
2881 for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2882 os << 'd' << i << ", ";
2883 if (map.getNumDims() >= 1)
2884 os << 'd' << map.getNumDims() - 1;
2885 os << ')';
2887 // Symbolic identifiers.
2888 if (map.getNumSymbols() != 0) {
2889 os << '[';
2890 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2891 os << 's' << i << ", ";
2892 if (map.getNumSymbols() >= 1)
2893 os << 's' << map.getNumSymbols() - 1;
2894 os << ']';
2897 // Result affine expressions.
2898 os << " -> (";
2899 interleaveComma(map.getResults(),
2900 [&](AffineExpr expr) { printAffineExpr(expr); });
2901 os << ')';
2904 void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
2905 // Dimension identifiers.
2906 os << '(';
2907 for (unsigned i = 1; i < set.getNumDims(); ++i)
2908 os << 'd' << i - 1 << ", ";
2909 if (set.getNumDims() >= 1)
2910 os << 'd' << set.getNumDims() - 1;
2911 os << ')';
2913 // Symbolic identifiers.
2914 if (set.getNumSymbols() != 0) {
2915 os << '[';
2916 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2917 os << 's' << i << ", ";
2918 if (set.getNumSymbols() >= 1)
2919 os << 's' << set.getNumSymbols() - 1;
2920 os << ']';
2923 // Print constraints.
2924 os << " : (";
2925 int numConstraints = set.getNumConstraints();
2926 for (int i = 1; i < numConstraints; ++i) {
2927 printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2928 os << ", ";
2930 if (numConstraints >= 1)
2931 printAffineConstraint(set.getConstraint(numConstraints - 1),
2932 set.isEq(numConstraints - 1));
2933 os << ')';
2936 //===----------------------------------------------------------------------===//
2937 // OperationPrinter
2938 //===----------------------------------------------------------------------===//
2940 namespace {
2941 /// This class contains the logic for printing operations, regions, and blocks.
2942 class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
2943 public:
2944 using Impl = AsmPrinter::Impl;
2945 using Impl::printType;
2947 explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
2948 : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
2950 /// Print the given top-level operation.
2951 void printTopLevelOperation(Operation *op);
2953 /// Print the given operation, including its left-hand side and its right-hand
2954 /// side, with its indent and location.
2955 void printFullOpWithIndentAndLoc(Operation *op);
2956 /// Print the given operation, including its left-hand side and its right-hand
2957 /// side, but not including indentation and location.
2958 void printFullOp(Operation *op);
2959 /// Print the right-hand size of the given operation in the custom or generic
2960 /// form.
2961 void printCustomOrGenericOp(Operation *op) override;
2962 /// Print the right-hand side of the given operation in the generic form.
2963 void printGenericOp(Operation *op, bool printOpName) override;
2965 /// Print the name of the given block.
2966 void printBlockName(Block *block);
2968 /// Print the given block. If 'printBlockArgs' is false, the arguments of the
2969 /// block are not printed. If 'printBlockTerminator' is false, the terminator
2970 /// operation of the block is not printed.
2971 void print(Block *block, bool printBlockArgs = true,
2972 bool printBlockTerminator = true);
2974 /// Print the ID of the given value, optionally with its result number.
2975 void printValueID(Value value, bool printResultNo = true,
2976 raw_ostream *streamOverride = nullptr) const;
2978 /// Print the ID of the given operation.
2979 void printOperationID(Operation *op,
2980 raw_ostream *streamOverride = nullptr) const;
2982 //===--------------------------------------------------------------------===//
2983 // OpAsmPrinter methods
2984 //===--------------------------------------------------------------------===//
2986 /// Print a loc(...) specifier if printing debug info is enabled. Locations
2987 /// may be deferred with an alias.
2988 void printOptionalLocationSpecifier(Location loc) override {
2989 printTrailingLocation(loc);
2992 /// Print a newline and indent the printer to the start of the current
2993 /// operation.
2994 void printNewline() override {
2995 os << newLine;
2996 os.indent(currentIndent);
2999 /// Increase indentation.
3000 void increaseIndent() override { currentIndent += indentWidth; }
3002 /// Decrease indentation.
3003 void decreaseIndent() override { currentIndent -= indentWidth; }
3005 /// Print a block argument in the usual format of:
3006 /// %ssaName : type {attr1=42} loc("here")
3007 /// where location printing is controlled by the standard internal option.
3008 /// You may pass omitType=true to not print a type, and pass an empty
3009 /// attribute list if you don't care for attributes.
3010 void printRegionArgument(BlockArgument arg,
3011 ArrayRef<NamedAttribute> argAttrs = {},
3012 bool omitType = false) override;
3014 /// Print the ID for the given value.
3015 void printOperand(Value value) override { printValueID(value); }
3016 void printOperand(Value value, raw_ostream &os) override {
3017 printValueID(value, /*printResultNo=*/true, &os);
3020 /// Print an optional attribute dictionary with a given set of elided values.
3021 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
3022 ArrayRef<StringRef> elidedAttrs = {}) override {
3023 Impl::printOptionalAttrDict(attrs, elidedAttrs);
3025 void printOptionalAttrDictWithKeyword(
3026 ArrayRef<NamedAttribute> attrs,
3027 ArrayRef<StringRef> elidedAttrs = {}) override {
3028 Impl::printOptionalAttrDict(attrs, elidedAttrs,
3029 /*withKeyword=*/true);
3032 /// Print the given successor.
3033 void printSuccessor(Block *successor) override;
3035 /// Print an operation successor with the operands used for the block
3036 /// arguments.
3037 void printSuccessorAndUseList(Block *successor,
3038 ValueRange succOperands) override;
3040 /// Print the given region.
3041 void printRegion(Region &region, bool printEntryBlockArgs,
3042 bool printBlockTerminators, bool printEmptyBlock) override;
3044 /// Renumber the arguments for the specified region to the same names as the
3045 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
3046 /// operations. If any entry in namesToUse is null, the corresponding
3047 /// argument name is left alone.
3048 void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
3049 state.getSSANameState().shadowRegionArgs(region, namesToUse);
3052 /// Print the given affine map with the symbol and dimension operands printed
3053 /// inline with the map.
3054 void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
3055 ValueRange operands) override;
3057 /// Print the given affine expression with the symbol and dimension operands
3058 /// printed inline with the expression.
3059 void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
3060 ValueRange symOperands) override;
3062 /// Print users of this operation or id of this operation if it has no result.
3063 void printUsersComment(Operation *op);
3065 /// Print users of this block arg.
3066 void printUsersComment(BlockArgument arg);
3068 /// Print the users of a value.
3069 void printValueUsers(Value value);
3071 /// Print either the ids of the result values or the id of the operation if
3072 /// the operation has no results.
3073 void printUserIDs(Operation *user, bool prefixComma = false);
3075 private:
3076 /// This class represents a resource builder implementation for the MLIR
3077 /// textual assembly format.
3078 class ResourceBuilder : public AsmResourceBuilder {
3079 public:
3080 using ValueFn = function_ref<void(raw_ostream &)>;
3081 using PrintFn = function_ref<void(StringRef, ValueFn)>;
3083 ResourceBuilder(OperationPrinter &p, PrintFn printFn)
3084 : p(p), printFn(printFn) {}
3085 ~ResourceBuilder() override = default;
3087 void buildBool(StringRef key, bool data) final {
3088 printFn(key, [&](raw_ostream &os) { p.os << (data ? "true" : "false"); });
3091 void buildString(StringRef key, StringRef data) final {
3092 printFn(key, [&](raw_ostream &os) { p.printEscapedString(data); });
3095 void buildBlob(StringRef key, ArrayRef<char> data,
3096 uint32_t dataAlignment) final {
3097 printFn(key, [&](raw_ostream &os) {
3098 // Store the blob in a hex string containing the alignment and the data.
3099 llvm::support::ulittle32_t dataAlignmentLE(dataAlignment);
3100 os << "\"0x"
3101 << llvm::toHex(StringRef(reinterpret_cast<char *>(&dataAlignmentLE),
3102 sizeof(dataAlignment)))
3103 << llvm::toHex(StringRef(data.data(), data.size())) << "\"";
3107 private:
3108 OperationPrinter &p;
3109 PrintFn printFn;
3112 /// Print the metadata dictionary for the file, eliding it if it is empty.
3113 void printFileMetadataDictionary(Operation *op);
3115 /// Print the resource sections for the file metadata dictionary.
3116 /// `checkAddMetadataDict` is used to indicate that metadata is going to be
3117 /// added, and the file metadata dictionary should be started if it hasn't
3118 /// yet.
3119 void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict,
3120 Operation *op);
3122 // Contains the stack of default dialects to use when printing regions.
3123 // A new dialect is pushed to the stack before parsing regions nested under an
3124 // operation implementing `OpAsmOpInterface`, and popped when done. At the
3125 // top-level we start with "builtin" as the default, so that the top-level
3126 // `module` operation prints as-is.
3127 SmallVector<StringRef> defaultDialectStack{"builtin"};
3129 /// The number of spaces used for indenting nested operations.
3130 const static unsigned indentWidth = 2;
3132 // This is the current indentation level for nested structures.
3133 unsigned currentIndent = 0;
3135 } // namespace
3137 void OperationPrinter::printTopLevelOperation(Operation *op) {
3138 // Output the aliases at the top level that can't be deferred.
3139 state.getAliasState().printNonDeferredAliases(*this, newLine);
3141 // Print the module.
3142 printFullOpWithIndentAndLoc(op);
3143 os << newLine;
3145 // Output the aliases at the top level that can be deferred.
3146 state.getAliasState().printDeferredAliases(*this, newLine);
3148 // Output any file level metadata.
3149 printFileMetadataDictionary(op);
3152 void OperationPrinter::printFileMetadataDictionary(Operation *op) {
3153 bool sawMetadataEntry = false;
3154 auto checkAddMetadataDict = [&] {
3155 if (!std::exchange(sawMetadataEntry, true))
3156 os << newLine << "{-#" << newLine;
3159 // Add the various types of metadata.
3160 printResourceFileMetadata(checkAddMetadataDict, op);
3162 // If the file dictionary exists, close it.
3163 if (sawMetadataEntry)
3164 os << newLine << "#-}" << newLine;
/// Print the `dialect_resources` and `external_resources` sections of the file
/// metadata dictionary. A section header is emitted only when some provider in
/// that section emits at least one entry. A provider's sub-dictionary is
/// emitted only when that provider emits at least one entry.
/// `checkAddMetadataDict` is invoked before each entry so the caller can open
/// the enclosing dictionary on demand.
void OperationPrinter::printResourceFileMetadata(
    function_ref<void()> checkAddMetadataDict, Operation *op) {
  // Functor used to add data entries to the file metadata dictionary.
  // hadResource: the current section's `<dictName>_resources: {` header has
  //   been emitted.
  // needResourceComma: an earlier section was emitted, so a new section header
  //   must be preceded by a comma.
  // needEntryComma: an earlier provider in the current section emitted an
  //   entry, so the next provider's `<name>: {` must be preceded by a comma.
  bool hadResource = false;
  bool needResourceComma = false;
  bool needEntryComma = false;
  auto processProvider = [&](StringRef dictName, StringRef name, auto &provider,
                             auto &&...providerArgs) {
    // Tracks whether *this* provider has emitted its `<name>: {` opener; it is
    // reset per provider, while needEntryComma persists across providers.
    bool hadEntry = false;
    auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
      checkAddMetadataDict();

      // Emit the top-level resource entry if we haven't yet.
      if (!std::exchange(hadResource, true)) {
        if (needResourceComma)
          os << "," << newLine;
        os << " " << dictName << "_resources: {" << newLine;
      }
      // Emit the parent resource entry if we haven't yet.
      if (!std::exchange(hadEntry, true)) {
        if (needEntryComma)
          os << "," << newLine;
        os << " " << name << ": {" << newLine;
      } else {
        // Separate consecutive keys within the same provider.
        os << "," << newLine;
      }

      os << " " << key << ": ";
      valueFn(os);
    };
    ResourceBuilder entryBuilder(*this, printFn);
    provider.buildResources(op, providerArgs..., entryBuilder);

    // Close this provider's sub-dictionary if it was opened, and remember that
    // later providers in this section need a leading comma.
    needEntryComma |= hadEntry;
    if (hadEntry)
      os << newLine << " }";
  };

  // Print the `dialect_resources` section if we have any dialects with
  // resources.
  for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) {
    auto &dialectResources = state.getDialectResources();
    StringRef name = interface.getDialect()->getNamespace();
    auto it = dialectResources.find(interface.getDialect());
    // Dialects without referenced resource handles still get a chance to
    // build resources, and are passed an empty handle set.
    if (it != dialectResources.end())
      processProvider("dialect", name, interface, it->second);
    else
      processProvider("dialect", name, interface,
                      SetVector<AsmDialectResourceHandle>());
  }
  if (hadResource)
    os << newLine << " }";

  // Print the `external_resources` section if we have any external clients with
  // resources.
  // Reset per-section state; a comma is needed before this section only if
  // the dialect section was emitted.
  needEntryComma = false;
  needResourceComma = hadResource;
  hadResource = false;
  for (const auto &printer : state.getResourcePrinters())
    processProvider("external", printer.getName(), printer);
  if (hadResource)
    os << newLine << " }";
}
3231 /// Print a block argument in the usual format of:
3232 /// %ssaName : type {attr1=42} loc("here")
3233 /// where location printing is controlled by the standard internal option.
3234 /// You may pass omitType=true to not print a type, and pass an empty
3235 /// attribute list if you don't care for attributes.
3236 void OperationPrinter::printRegionArgument(BlockArgument arg,
3237 ArrayRef<NamedAttribute> argAttrs,
3238 bool omitType) {
3239 printOperand(arg);
3240 if (!omitType) {
3241 os << ": ";
3242 printType(arg.getType());
3244 printOptionalAttrDict(argAttrs);
3245 // TODO: We should allow location aliases on block arguments.
3246 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
3249 void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) {
3250 // Track the location of this operation.
3251 state.registerOperationLocation(op, newLine.curLine, currentIndent);
3253 os.indent(currentIndent);
3254 printFullOp(op);
3255 printTrailingLocation(op->getLoc());
3256 if (printerFlags.shouldPrintValueUsers())
3257 printUsersComment(op);
3260 void OperationPrinter::printFullOp(Operation *op) {
3261 if (size_t numResults = op->getNumResults()) {
3262 auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
3263 printValueID(op->getResult(resultNo), /*printResultNo=*/false);
3264 if (resultCount > 1)
3265 os << ':' << resultCount;
3268 // Check to see if this operation has multiple result groups.
3269 ArrayRef<int> resultGroups = state.getSSANameState().getOpResultGroups(op);
3270 if (!resultGroups.empty()) {
3271 // Interleave the groups excluding the last one, this one will be handled
3272 // separately.
3273 interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
3274 printResultGroup(resultGroups[i],
3275 resultGroups[i + 1] - resultGroups[i]);
3277 os << ", ";
3278 printResultGroup(resultGroups.back(), numResults - resultGroups.back());
3280 } else {
3281 printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
3284 os << " = ";
3287 printCustomOrGenericOp(op);
3290 void OperationPrinter::printUsersComment(Operation *op) {
3291 unsigned numResults = op->getNumResults();
3292 if (!numResults && op->getNumOperands()) {
3293 os << " // id: ";
3294 printOperationID(op);
3295 } else if (numResults && op->use_empty()) {
3296 os << " // unused";
3297 } else if (numResults && !op->use_empty()) {
3298 // Print "user" if the operation has one result used to compute one other
3299 // result, or is used in one operation with no result.
3300 unsigned usedInNResults = 0;
3301 unsigned usedInNOperations = 0;
3302 SmallPtrSet<Operation *, 1> userSet;
3303 for (Operation *user : op->getUsers()) {
3304 if (userSet.insert(user).second) {
3305 ++usedInNOperations;
3306 usedInNResults += user->getNumResults();
3310 // We already know that users is not empty.
3311 bool exactlyOneUniqueUse =
3312 usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1;
3313 os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": ";
3314 bool shouldPrintBrackets = numResults > 1;
3315 auto printOpResult = [&](OpResult opResult) {
3316 if (shouldPrintBrackets)
3317 os << "(";
3318 printValueUsers(opResult);
3319 if (shouldPrintBrackets)
3320 os << ")";
3323 interleaveComma(op->getResults(), printOpResult);
3327 void OperationPrinter::printUsersComment(BlockArgument arg) {
3328 os << "// ";
3329 printValueID(arg);
3330 if (arg.use_empty()) {
3331 os << " is unused";
3332 } else {
3333 os << " is used by ";
3334 printValueUsers(arg);
3336 os << newLine;
3339 void OperationPrinter::printValueUsers(Value value) {
3340 if (value.use_empty())
3341 os << "unused";
3343 // One value might be used as the operand of an operation more than once.
3344 // Only print the operations results once in that case.
3345 SmallPtrSet<Operation *, 1> userSet;
3346 for (auto [index, user] : enumerate(value.getUsers())) {
3347 if (userSet.insert(user).second)
3348 printUserIDs(user, index);
3352 void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
3353 if (prefixComma)
3354 os << ", ";
3356 if (!user->getNumResults()) {
3357 printOperationID(user);
3358 } else {
3359 interleaveComma(user->getResults(),
3360 [this](Value result) { printValueID(result); });
3364 void OperationPrinter::printCustomOrGenericOp(Operation *op) {
3365 // If requested, always print the generic form.
3366 if (!printerFlags.shouldPrintGenericOpForm()) {
3367 // Check to see if this is a known operation. If so, use the registered
3368 // custom printer hook.
3369 if (auto opInfo = op->getRegisteredInfo()) {
3370 opInfo->printAssembly(op, *this, defaultDialectStack.back());
3371 return;
3373 // Otherwise try to dispatch to the dialect, if available.
3374 if (Dialect *dialect = op->getDialect()) {
3375 if (auto opPrinter = dialect->getOperationPrinter(op)) {
3376 // Print the op name first.
3377 StringRef name = op->getName().getStringRef();
3378 // Only drop the default dialect prefix when it cannot lead to
3379 // ambiguities.
3380 if (name.count('.') == 1)
3381 name.consume_front((defaultDialectStack.back() + ".").str());
3382 os << name;
3384 // Print the rest of the op now.
3385 opPrinter(op, *this);
3386 return;
3391 // Otherwise print with the generic assembly form.
3392 printGenericOp(op, /*printOpName=*/true);
3395 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
3396 if (printOpName)
3397 printEscapedString(op->getName().getStringRef());
3398 os << '(';
3399 interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
3400 os << ')';
3402 // For terminators, print the list of successors and their operands.
3403 if (op->getNumSuccessors() != 0) {
3404 os << '[';
3405 interleaveComma(op->getSuccessors(),
3406 [&](Block *successor) { printBlockName(successor); });
3407 os << ']';
3410 // Print the properties.
3411 if (Attribute prop = op->getPropertiesAsAttribute()) {
3412 os << " <";
3413 Impl::printAttribute(prop);
3414 os << '>';
3417 // Print regions.
3418 if (op->getNumRegions() != 0) {
3419 os << " (";
3420 interleaveComma(op->getRegions(), [&](Region &region) {
3421 printRegion(region, /*printEntryBlockArgs=*/true,
3422 /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
3424 os << ')';
3427 auto attrs = op->getDiscardableAttrs();
3428 printOptionalAttrDict(attrs);
3430 // Print the type signature of the operation.
3431 os << " : ";
3432 printFunctionalType(op);
3435 void OperationPrinter::printBlockName(Block *block) {
3436 os << state.getSSANameState().getBlockInfo(block).name;
3439 void OperationPrinter::print(Block *block, bool printBlockArgs,
3440 bool printBlockTerminator) {
3441 // Print the block label and argument list if requested.
3442 if (printBlockArgs) {
3443 os.indent(currentIndent);
3444 printBlockName(block);
3446 // Print the argument list if non-empty.
3447 if (!block->args_empty()) {
3448 os << '(';
3449 interleaveComma(block->getArguments(), [&](BlockArgument arg) {
3450 printValueID(arg);
3451 os << ": ";
3452 printType(arg.getType());
3453 // TODO: We should allow location aliases on block arguments.
3454 printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
3456 os << ')';
3458 os << ':';
3460 // Print out some context information about the predecessors of this block.
3461 if (!block->getParent()) {
3462 os << " // block is not in a region!";
3463 } else if (block->hasNoPredecessors()) {
3464 if (!block->isEntryBlock())
3465 os << " // no predecessors";
3466 } else if (auto *pred = block->getSinglePredecessor()) {
3467 os << " // pred: ";
3468 printBlockName(pred);
3469 } else {
3470 // We want to print the predecessors in a stable order, not in
3471 // whatever order the use-list is in, so gather and sort them.
3472 SmallVector<BlockInfo, 4> predIDs;
3473 for (auto *pred : block->getPredecessors())
3474 predIDs.push_back(state.getSSANameState().getBlockInfo(pred));
3475 llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
3476 return lhs.ordering < rhs.ordering;
3479 os << " // " << predIDs.size() << " preds: ";
3481 interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
3483 os << newLine;
3486 currentIndent += indentWidth;
3488 if (printerFlags.shouldPrintValueUsers()) {
3489 for (BlockArgument arg : block->getArguments()) {
3490 os.indent(currentIndent);
3491 printUsersComment(arg);
3495 bool hasTerminator =
3496 !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
3497 auto range = llvm::make_range(
3498 block->begin(),
3499 std::prev(block->end(),
3500 (!hasTerminator || printBlockTerminator) ? 0 : 1));
3501 for (auto &op : range) {
3502 printFullOpWithIndentAndLoc(&op);
3503 os << newLine;
3505 currentIndent -= indentWidth;
3508 void OperationPrinter::printValueID(Value value, bool printResultNo,
3509 raw_ostream *streamOverride) const {
3510 state.getSSANameState().printValueID(value, printResultNo,
3511 streamOverride ? *streamOverride : os);
3514 void OperationPrinter::printOperationID(Operation *op,
3515 raw_ostream *streamOverride) const {
3516 state.getSSANameState().printOperationID(op, streamOverride ? *streamOverride
3517 : os);
3520 void OperationPrinter::printSuccessor(Block *successor) {
3521 printBlockName(successor);
3524 void OperationPrinter::printSuccessorAndUseList(Block *successor,
3525 ValueRange succOperands) {
3526 printBlockName(successor);
3527 if (succOperands.empty())
3528 return;
3530 os << '(';
3531 interleaveComma(succOperands,
3532 [this](Value operand) { printValueID(operand); });
3533 os << " : ";
3534 interleaveComma(succOperands,
3535 [this](Value operand) { printType(operand.getType()); });
3536 os << ')';
3539 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
3540 bool printBlockTerminators,
3541 bool printEmptyBlock) {
3542 if (printerFlags.shouldSkipRegions()) {
3543 os << "{...}";
3544 return;
3546 os << "{" << newLine;
3547 if (!region.empty()) {
3548 auto restoreDefaultDialect =
3549 llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); });
3550 if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
3551 defaultDialectStack.push_back(iface.getDefaultDialect());
3552 else
3553 defaultDialectStack.push_back("");
3555 auto *entryBlock = &region.front();
3556 // Force printing the block header if printEmptyBlock is set and the block
3557 // is empty or if printEntryBlockArgs is set and there are arguments to
3558 // print.
3559 bool shouldAlwaysPrintBlockHeader =
3560 (printEmptyBlock && entryBlock->empty()) ||
3561 (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
3562 print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
3563 for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
3564 print(&b);
3566 os.indent(currentIndent) << "}";
3569 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
3570 ValueRange operands) {
3571 if (!mapAttr) {
3572 os << "<<NULL AFFINE MAP>>";
3573 return;
3575 AffineMap map = mapAttr.getValue();
3576 unsigned numDims = map.getNumDims();
3577 auto printValueName = [&](unsigned pos, bool isSymbol) {
3578 unsigned index = isSymbol ? numDims + pos : pos;
3579 assert(index < operands.size());
3580 if (isSymbol)
3581 os << "symbol(";
3582 printValueID(operands[index]);
3583 if (isSymbol)
3584 os << ')';
3587 interleaveComma(map.getResults(), [&](AffineExpr expr) {
3588 printAffineExpr(expr, printValueName);
3592 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
3593 ValueRange dimOperands,
3594 ValueRange symOperands) {
3595 auto printValueName = [&](unsigned pos, bool isSymbol) {
3596 if (!isSymbol)
3597 return printValueID(dimOperands[pos]);
3598 os << "symbol(";
3599 printValueID(symOperands[pos]);
3600 os << ')';
3602 printAffineExpr(expr, printValueName);
3605 //===----------------------------------------------------------------------===//
3606 // print and dump methods
3607 //===----------------------------------------------------------------------===//
3609 void Attribute::print(raw_ostream &os, bool elideType) const {
3610 if (!*this) {
3611 os << "<<NULL ATTRIBUTE>>";
3612 return;
3615 AsmState state(getContext());
3616 print(os, state, elideType);
3618 void Attribute::print(raw_ostream &os, AsmState &state, bool elideType) const {
3619 using AttrTypeElision = AsmPrinter::Impl::AttrTypeElision;
3620 AsmPrinter::Impl(os, state.getImpl())
3621 .printAttribute(*this, elideType ? AttrTypeElision::Must
3622 : AttrTypeElision::Never);
3625 void Attribute::dump() const {
3626 print(llvm::errs());
3627 llvm::errs() << "\n";
3630 void Type::print(raw_ostream &os) const {
3631 if (!*this) {
3632 os << "<<NULL TYPE>>";
3633 return;
3636 AsmState state(getContext());
3637 print(os, state);
3639 void Type::print(raw_ostream &os, AsmState &state) const {
3640 AsmPrinter::Impl(os, state.getImpl()).printType(*this);
3643 void Type::dump() const {
3644 print(llvm::errs());
3645 llvm::errs() << "\n";
3648 void AffineMap::dump() const {
3649 print(llvm::errs());
3650 llvm::errs() << "\n";
3653 void IntegerSet::dump() const {
3654 print(llvm::errs());
3655 llvm::errs() << "\n";
3658 void AffineExpr::print(raw_ostream &os) const {
3659 if (!expr) {
3660 os << "<<NULL AFFINE EXPR>>";
3661 return;
3663 AsmState state(getContext());
3664 AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(*this);
3667 void AffineExpr::dump() const {
3668 print(llvm::errs());
3669 llvm::errs() << "\n";
3672 void AffineMap::print(raw_ostream &os) const {
3673 if (!map) {
3674 os << "<<NULL AFFINE MAP>>";
3675 return;
3677 AsmState state(getContext());
3678 AsmPrinter::Impl(os, state.getImpl()).printAffineMap(*this);
3681 void IntegerSet::print(raw_ostream &os) const {
3682 AsmState state(getContext());
3683 AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(*this);
3686 void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); }
3687 void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
3688 if (!impl) {
3689 os << "<<NULL VALUE>>";
3690 return;
3693 if (auto *op = getDefiningOp())
3694 return op->print(os, flags);
3695 // TODO: Improve BlockArgument print'ing.
3696 BlockArgument arg = llvm::cast<BlockArgument>(*this);
3697 os << "<block argument> of type '" << arg.getType()
3698 << "' at index: " << arg.getArgNumber();
3700 void Value::print(raw_ostream &os, AsmState &state) {
3701 if (!impl) {
3702 os << "<<NULL VALUE>>";
3703 return;
3706 if (auto *op = getDefiningOp())
3707 return op->print(os, state);
3709 // TODO: Improve BlockArgument print'ing.
3710 BlockArgument arg = llvm::cast<BlockArgument>(*this);
3711 os << "<block argument> of type '" << arg.getType()
3712 << "' at index: " << arg.getArgNumber();
3715 void Value::dump() {
3716 print(llvm::errs());
3717 llvm::errs() << "\n";
3720 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
3721 // TODO: This doesn't necessarily capture all potential cases.
3722 // Currently, region arguments can be shadowed when printing the main
3723 // operation. If the IR hasn't been printed, this will produce the old SSA
3724 // name and not the shadowed name.
3725 state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
3726 os);
3729 static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
3730 do {
3731 // If we are printing local scope, stop at the first operation that is
3732 // isolated from above.
3733 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
3734 break;
3736 // Otherwise, traverse up to the next parent.
3737 Operation *parentOp = op->getParentOp();
3738 if (!parentOp)
3739 break;
3740 op = parentOp;
3741 } while (true);
3742 return op;
3745 void Value::printAsOperand(raw_ostream &os, const OpPrintingFlags &flags) {
3746 Operation *op;
3747 if (auto result = llvm::dyn_cast<OpResult>(*this)) {
3748 op = result.getOwner();
3749 } else {
3750 op = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp();
3751 if (!op) {
3752 os << "<<UNKNOWN SSA VALUE>>";
3753 return;
3756 op = findParent(op, flags.shouldUseLocalScope());
3757 AsmState state(op, flags);
3758 printAsOperand(os, state);
3761 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
3762 // Find the operation to number from based upon the provided flags.
3763 Operation *op = findParent(this, printerFlags.shouldUseLocalScope());
3764 AsmState state(op, printerFlags);
3765 print(os, state);
3767 void Operation::print(raw_ostream &os, AsmState &state) {
3768 OperationPrinter printer(os, state.getImpl());
3769 if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) {
3770 state.getImpl().initializeAliases(this);
3771 printer.printTopLevelOperation(this);
3772 } else {
3773 printer.printFullOpWithIndentAndLoc(this);
3777 void Operation::dump() {
3778 print(llvm::errs(), OpPrintingFlags().useLocalScope());
3779 llvm::errs() << "\n";
3782 void Block::print(raw_ostream &os) {
3783 Operation *parentOp = getParentOp();
3784 if (!parentOp) {
3785 os << "<<UNLINKED BLOCK>>\n";
3786 return;
3788 // Get the top-level op.
3789 while (auto *nextOp = parentOp->getParentOp())
3790 parentOp = nextOp;
3792 AsmState state(parentOp);
3793 print(os, state);
3795 void Block::print(raw_ostream &os, AsmState &state) {
3796 OperationPrinter(os, state.getImpl()).print(this);
3799 void Block::dump() { print(llvm::errs()); }
3801 /// Print out the name of the block without printing its body.
3802 void Block::printAsOperand(raw_ostream &os, bool printType) {
3803 Operation *parentOp = getParentOp();
3804 if (!parentOp) {
3805 os << "<<UNLINKED BLOCK>>\n";
3806 return;
3808 AsmState state(parentOp);
3809 printAsOperand(os, state);
3811 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
3812 OperationPrinter printer(os, state.getImpl());
3813 printer.printBlockName(this);