TargetParser: AArch64: Add part numbers for Apple CPUs.
[llvm-project.git] / mlir / lib / Transforms / ViewOpGraph.cpp
blobfa0af7665ba4c4cca9eb69b8df5b9f56aafa107e
1 //===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Transforms/ViewOpGraph.h"
11 #include "mlir/Analysis/TopologicalSortUtils.h"
12 #include "mlir/IR/Block.h"
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Support/IndentedOstream.h"
17 #include "llvm/Support/Format.h"
18 #include "llvm/Support/GraphWriter.h"
19 #include <map>
20 #include <optional>
21 #include <utility>
23 namespace mlir {
24 #define GEN_PASS_DEF_VIEWOPGRAPH
25 #include "mlir/Transforms/Passes.h.inc"
26 } // namespace mlir
28 using namespace mlir;
30 static const StringRef kLineStyleControlFlow = "dashed";
31 static const StringRef kLineStyleDataFlow = "solid";
32 static const StringRef kShapeNode = "ellipse";
33 static const StringRef kShapeNone = "plain";
35 /// Return the size limits for eliding large attributes.
36 static int64_t getLargeAttributeSizeLimit() {
37 // Use the default from the printer flags if possible.
38 if (std::optional<int64_t> limit =
39 OpPrintingFlags().getLargeElementsAttrLimit())
40 return *limit;
41 return 16;
44 /// Return all values printed onto a stream as a string.
45 static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
46 std::string buf;
47 llvm::raw_string_ostream os(buf);
48 func(os);
49 return buf;
52 /// Escape special characters such as '\n' and quotation marks.
53 static std::string escapeString(std::string str) {
54 return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
/// Wrap `str` in double quotation marks. No escaping is performed here;
/// callers escape the text first when needed.
static std::string quoteString(const std::string &str) {
  std::string quoted;
  quoted.reserve(str.size() + 2);
  quoted += '"';
  quoted += str;
  quoted += '"';
  return quoted;
}
/// A Graphviz attribute list: attribute name -> already-formatted value.
/// `std::map` keeps entries sorted by name, so attribute lists are emitted in a
/// deterministic order.
using AttributeMap = std::map<std::string, std::string>;
64 namespace {
66 /// This struct represents a node in the DOT language. Each node has an
67 /// identifier and an optional identifier for the cluster (subgraph) that
68 /// contains the node.
69 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
70 /// not between clusters. However, edges can be clipped to the boundary of a
71 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
72 /// cluster, an invisible "anchor" node is created.
73 struct Node {
74 public:
75 Node(int id = 0, std::optional<int> clusterId = std::nullopt)
76 : id(id), clusterId(clusterId) {}
78 int id;
79 std::optional<int> clusterId;
82 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
83 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
84 /// about the Graphviz DOT language.
85 class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
86 public:
87 PrintOpPass(raw_ostream &os) : os(os) {}
88 PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
90 void runOnOperation() override {
91 initColorMapping(*getOperation());
92 emitGraph([&]() {
93 processOperation(getOperation());
94 emitAllEdgeStmts();
95 });
96 markAllAnalysesPreserved();
99 /// Create a CFG graph for a region. Used in `Region::viewGraph`.
100 void emitRegionCFG(Region &region) {
101 printControlFlowEdges = true;
102 printDataFlowEdges = false;
103 initColorMapping(region);
104 emitGraph([&]() { processRegion(region); });
107 private:
108 /// Generate a color mapping that will color every operation with the same
109 /// name the same way. It'll interpolate the hue in the HSV color-space,
110 /// attempting to keep the contrast suitable for black text.
111 template <typename T>
112 void initColorMapping(T &irEntity) {
113 backgroundColors.clear();
114 SmallVector<Operation *> ops;
115 irEntity.walk([&](Operation *op) {
116 auto &entry = backgroundColors[op->getName()];
117 if (entry.first == 0)
118 ops.push_back(op);
119 ++entry.first;
121 for (auto indexedOps : llvm::enumerate(ops)) {
122 double hue = ((double)indexedOps.index()) / ops.size();
123 backgroundColors[indexedOps.value()->getName()].second =
124 std::to_string(hue) + " 1.0 1.0";
128 /// Emit all edges. This function should be called after all nodes have been
129 /// emitted.
130 void emitAllEdgeStmts() {
131 if (printDataFlowEdges) {
132 for (const auto &[value, node, label] : dataFlowEdges) {
133 emitEdgeStmt(valueToNode[value], node, label, kLineStyleDataFlow);
137 for (const std::string &edge : edges)
138 os << edge << ";\n";
139 edges.clear();
142 /// Emit a cluster (subgraph). The specified builder generates the body of the
143 /// cluster. Return the anchor node of the cluster.
144 Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
145 int clusterId = ++counter;
146 os << "subgraph cluster_" << clusterId << " {\n";
147 os.indent();
148 // Emit invisible anchor node from/to which arrows can be drawn.
149 Node anchorNode = emitNodeStmt(" ", kShapeNone);
150 os << attrStmt("label", quoteString(escapeString(std::move(label))))
151 << ";\n";
152 builder();
153 os.unindent();
154 os << "}\n";
155 return Node(anchorNode.id, clusterId);
158 /// Generate an attribute statement.
159 std::string attrStmt(const Twine &key, const Twine &value) {
160 return (key + " = " + value).str();
163 /// Emit an attribute list.
164 void emitAttrList(raw_ostream &os, const AttributeMap &map) {
165 os << "[";
166 interleaveComma(map, os, [&](const auto &it) {
167 os << this->attrStmt(it.first, it.second);
169 os << "]";
172 // Print an MLIR attribute to `os`. Large attributes are truncated.
173 void emitMlirAttr(raw_ostream &os, Attribute attr) {
174 // A value used to elide large container attribute.
175 int64_t largeAttrLimit = getLargeAttributeSizeLimit();
177 // Always emit splat attributes.
178 if (isa<SplatElementsAttr>(attr)) {
179 attr.print(os);
180 return;
183 // Elide "big" elements attributes.
184 auto elements = dyn_cast<ElementsAttr>(attr);
185 if (elements && elements.getNumElements() > largeAttrLimit) {
186 os << std::string(elements.getShapedType().getRank(), '[') << "..."
187 << std::string(elements.getShapedType().getRank(), ']') << " : "
188 << elements.getType();
189 return;
192 auto array = dyn_cast<ArrayAttr>(attr);
193 if (array && static_cast<int64_t>(array.size()) > largeAttrLimit) {
194 os << "[...]";
195 return;
198 // Print all other attributes.
199 std::string buf;
200 llvm::raw_string_ostream ss(buf);
201 attr.print(ss);
202 os << truncateString(buf);
205 /// Append an edge to the list of edges.
206 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
207 void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
208 AttributeMap attrs;
209 attrs["style"] = style.str();
210 // Do not label edges that start/end at a cluster boundary. Such edges are
211 // clipped at the boundary, but labels are not. This can lead to labels
212 // floating around without any edge next to them.
213 if (!n1.clusterId && !n2.clusterId)
214 attrs["label"] = quoteString(escapeString(std::move(label)));
215 // Use `ltail` and `lhead` to draw edges between clusters.
216 if (n1.clusterId)
217 attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
218 if (n2.clusterId)
219 attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
221 edges.push_back(strFromOs([&](raw_ostream &os) {
222 os << llvm::format("v%i -> v%i ", n1.id, n2.id);
223 emitAttrList(os, attrs);
224 }));
227 /// Emit a graph. The specified builder generates the body of the graph.
228 void emitGraph(function_ref<void()> builder) {
229 os << "digraph G {\n";
230 os.indent();
231 // Edges between clusters are allowed only in compound mode.
232 os << attrStmt("compound", "true") << ";\n";
233 builder();
234 os.unindent();
235 os << "}\n";
238 /// Emit a node statement.
239 Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
240 StringRef background = "") {
241 int nodeId = ++counter;
242 AttributeMap attrs;
243 attrs["label"] = quoteString(escapeString(std::move(label)));
244 attrs["shape"] = shape.str();
245 if (!background.empty()) {
246 attrs["style"] = "filled";
247 attrs["fillcolor"] = ("\"" + background + "\"").str();
249 os << llvm::format("v%i ", nodeId);
250 emitAttrList(os, attrs);
251 os << ";\n";
252 return Node(nodeId);
255 /// Generate a label for an operation.
256 std::string getLabel(Operation *op) {
257 return strFromOs([&](raw_ostream &os) {
258 // Print operation name and type.
259 os << op->getName();
260 if (printResultTypes) {
261 os << " : (";
262 std::string buf;
263 llvm::raw_string_ostream ss(buf);
264 interleaveComma(op->getResultTypes(), ss);
265 os << truncateString(buf) << ")";
268 // Print attributes.
269 if (printAttrs) {
270 os << "\n";
271 for (const NamedAttribute &attr : op->getAttrs()) {
272 os << '\n' << attr.getName().getValue() << ": ";
273 emitMlirAttr(os, attr.getValue());
279 /// Generate a label for a block argument.
280 std::string getLabel(BlockArgument arg) {
281 return "arg" + std::to_string(arg.getArgNumber());
284 /// Process a block. Emit a cluster and one node per block argument and
285 /// operation inside the cluster.
286 void processBlock(Block &block) {
287 emitClusterStmt([&]() {
288 for (BlockArgument &blockArg : block.getArguments())
289 valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
291 // Emit a node for each operation.
292 std::optional<Node> prevNode;
293 for (Operation &op : block) {
294 Node nextNode = processOperation(&op);
295 if (printControlFlowEdges && prevNode)
296 emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
297 kLineStyleControlFlow);
298 prevNode = nextNode;
303 /// Process an operation. If the operation has regions, emit a cluster.
304 /// Otherwise, emit a node.
305 Node processOperation(Operation *op) {
306 Node node;
307 if (op->getNumRegions() > 0) {
308 // Emit cluster for op with regions.
309 node = emitClusterStmt(
310 [&]() {
311 for (Region &region : op->getRegions())
312 processRegion(region);
314 getLabel(op));
315 } else {
316 node = emitNodeStmt(getLabel(op), kShapeNode,
317 backgroundColors[op->getName()].second);
320 // Insert data flow edges originating from each operand.
321 if (printDataFlowEdges) {
322 unsigned numOperands = op->getNumOperands();
323 for (unsigned i = 0; i < numOperands; i++)
324 dataFlowEdges.push_back({op->getOperand(i), node,
325 numOperands == 1 ? "" : std::to_string(i)});
328 for (Value result : op->getResults())
329 valueToNode[result] = node;
331 return node;
334 /// Process a region.
335 void processRegion(Region &region) {
336 for (Block &block : region.getBlocks())
337 processBlock(block);
340 /// Truncate long strings.
341 std::string truncateString(std::string str) {
342 if (str.length() <= maxLabelLen)
343 return str;
344 return str.substr(0, maxLabelLen) + "...";
347 /// Output stream to write DOT file to.
348 raw_indented_ostream os;
349 /// A list of edges. For simplicity, should be emitted after all nodes were
350 /// emitted.
351 std::vector<std::string> edges;
352 /// Mapping of SSA values to Graphviz nodes/clusters.
353 DenseMap<Value, Node> valueToNode;
354 /// Output for data flow edges is delayed until the end to handle cycles
355 std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges;
356 /// Counter for generating unique node/subgraph identifiers.
357 int counter = 0;
359 DenseMap<OperationName, std::pair<int, std::string>> backgroundColors;
362 } // namespace
364 std::unique_ptr<Pass> mlir::createPrintOpGraphPass(raw_ostream &os) {
365 return std::make_unique<PrintOpPass>(os);
368 /// Generate a CFG for a region and show it in a window.
369 static void llvmViewGraph(Region &region, const Twine &name) {
370 int fd;
371 std::string filename = llvm::createGraphFilename(name.str(), fd);
373 llvm::raw_fd_ostream os(fd, /*shouldClose=*/true);
374 if (fd == -1) {
375 llvm::errs() << "error opening file '" << filename << "' for writing\n";
376 return;
378 PrintOpPass pass(os);
379 pass.emitRegionCFG(region);
381 llvm::DisplayGraph(filename, /*wait=*/false, llvm::GraphProgram::DOT);
384 void mlir::Region::viewGraph(const Twine &regionName) {
385 llvmViewGraph(*this, regionName);
388 void mlir::Region::viewGraph() { viewGraph("region"); }