//===- ViewOpGraph.cpp - View/write op graphviz graphs --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/ViewOpGraph.h"

#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/IndentedOstream.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/GraphWriter.h"
#include <map>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
24 #define GEN_PASS_DEF_VIEWOPGRAPH
25 #include "mlir/Transforms/Passes.h.inc"
30 static const StringRef kLineStyleControlFlow
= "dashed";
31 static const StringRef kLineStyleDataFlow
= "solid";
32 static const StringRef kShapeNode
= "ellipse";
33 static const StringRef kShapeNone
= "plain";
35 /// Return the size limits for eliding large attributes.
36 static int64_t getLargeAttributeSizeLimit() {
37 // Use the default from the printer flags if possible.
38 if (std::optional
<int64_t> limit
=
39 OpPrintingFlags().getLargeElementsAttrLimit())
44 /// Return all values printed onto a stream as a string.
45 static std::string
strFromOs(function_ref
<void(raw_ostream
&)> func
) {
47 llvm::raw_string_ostream
os(buf
);
52 /// Escape special characters such as '\n' and quotation marks.
53 static std::string
escapeString(std::string str
) {
54 return strFromOs([&](raw_ostream
&os
) { os
.write_escaped(str
); });
/// Put quotation marks around a given string. The contents are not escaped;
/// callers that need escaping combine this with `escapeString`.
static std::string quoteString(const std::string &str) {
  return "\"" + str + "\"";
}
/// Attribute list of a DOT statement: attribute name -> attribute value.
/// `std::map` keeps the keys sorted, so attributes are emitted in a
/// deterministic order.
using AttributeMap = std::map<std::string, std::string>;
66 /// This struct represents a node in the DOT language. Each node has an
67 /// identifier and an optional identifier for the cluster (subgraph) that
68 /// contains the node.
69 /// Note: In the DOT language, edges can be drawn only from nodes to nodes, but
70 /// not between clusters. However, edges can be clipped to the boundary of a
71 /// cluster with `lhead` and `ltail` attributes. Therefore, when creating a new
72 /// cluster, an invisible "anchor" node is created.
75 Node(int id
= 0, std::optional
<int> clusterId
= std::nullopt
)
76 : id(id
), clusterId(clusterId
) {}
79 std::optional
<int> clusterId
;
82 /// This pass generates a Graphviz dataflow visualization of an MLIR operation.
83 /// Note: See https://www.graphviz.org/doc/info/lang.html for more information
84 /// about the Graphviz DOT language.
85 class PrintOpPass
: public impl::ViewOpGraphBase
<PrintOpPass
> {
87 PrintOpPass(raw_ostream
&os
) : os(os
) {}
88 PrintOpPass(const PrintOpPass
&o
) : PrintOpPass(o
.os
.getOStream()) {}
90 void runOnOperation() override
{
91 initColorMapping(*getOperation());
93 processOperation(getOperation());
96 markAllAnalysesPreserved();
99 /// Create a CFG graph for a region. Used in `Region::viewGraph`.
100 void emitRegionCFG(Region
®ion
) {
101 printControlFlowEdges
= true;
102 printDataFlowEdges
= false;
103 initColorMapping(region
);
104 emitGraph([&]() { processRegion(region
); });
108 /// Generate a color mapping that will color every operation with the same
109 /// name the same way. It'll interpolate the hue in the HSV color-space,
110 /// attempting to keep the contrast suitable for black text.
111 template <typename T
>
112 void initColorMapping(T
&irEntity
) {
113 backgroundColors
.clear();
114 SmallVector
<Operation
*> ops
;
115 irEntity
.walk([&](Operation
*op
) {
116 auto &entry
= backgroundColors
[op
->getName()];
117 if (entry
.first
== 0)
121 for (auto indexedOps
: llvm::enumerate(ops
)) {
122 double hue
= ((double)indexedOps
.index()) / ops
.size();
123 backgroundColors
[indexedOps
.value()->getName()].second
=
124 std::to_string(hue
) + " 1.0 1.0";
128 /// Emit all edges. This function should be called after all nodes have been
130 void emitAllEdgeStmts() {
131 if (printDataFlowEdges
) {
132 for (const auto &[value
, node
, label
] : dataFlowEdges
) {
133 emitEdgeStmt(valueToNode
[value
], node
, label
, kLineStyleDataFlow
);
137 for (const std::string
&edge
: edges
)
142 /// Emit a cluster (subgraph). The specified builder generates the body of the
143 /// cluster. Return the anchor node of the cluster.
144 Node
emitClusterStmt(function_ref
<void()> builder
, std::string label
= "") {
145 int clusterId
= ++counter
;
146 os
<< "subgraph cluster_" << clusterId
<< " {\n";
148 // Emit invisible anchor node from/to which arrows can be drawn.
149 Node anchorNode
= emitNodeStmt(" ", kShapeNone
);
150 os
<< attrStmt("label", quoteString(escapeString(std::move(label
))))
155 return Node(anchorNode
.id
, clusterId
);
158 /// Generate an attribute statement.
159 std::string
attrStmt(const Twine
&key
, const Twine
&value
) {
160 return (key
+ " = " + value
).str();
163 /// Emit an attribute list.
164 void emitAttrList(raw_ostream
&os
, const AttributeMap
&map
) {
166 interleaveComma(map
, os
, [&](const auto &it
) {
167 os
<< this->attrStmt(it
.first
, it
.second
);
172 // Print an MLIR attribute to `os`. Large attributes are truncated.
173 void emitMlirAttr(raw_ostream
&os
, Attribute attr
) {
174 // A value used to elide large container attribute.
175 int64_t largeAttrLimit
= getLargeAttributeSizeLimit();
177 // Always emit splat attributes.
178 if (isa
<SplatElementsAttr
>(attr
)) {
183 // Elide "big" elements attributes.
184 auto elements
= dyn_cast
<ElementsAttr
>(attr
);
185 if (elements
&& elements
.getNumElements() > largeAttrLimit
) {
186 os
<< std::string(elements
.getShapedType().getRank(), '[') << "..."
187 << std::string(elements
.getShapedType().getRank(), ']') << " : "
188 << elements
.getType();
192 auto array
= dyn_cast
<ArrayAttr
>(attr
);
193 if (array
&& static_cast<int64_t>(array
.size()) > largeAttrLimit
) {
198 // Print all other attributes.
200 llvm::raw_string_ostream
ss(buf
);
202 os
<< truncateString(buf
);
205 /// Append an edge to the list of edges.
206 /// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
207 void emitEdgeStmt(Node n1
, Node n2
, std::string label
, StringRef style
) {
209 attrs
["style"] = style
.str();
210 // Do not label edges that start/end at a cluster boundary. Such edges are
211 // clipped at the boundary, but labels are not. This can lead to labels
212 // floating around without any edge next to them.
213 if (!n1
.clusterId
&& !n2
.clusterId
)
214 attrs
["label"] = quoteString(escapeString(std::move(label
)));
215 // Use `ltail` and `lhead` to draw edges between clusters.
217 attrs
["ltail"] = "cluster_" + std::to_string(*n1
.clusterId
);
219 attrs
["lhead"] = "cluster_" + std::to_string(*n2
.clusterId
);
221 edges
.push_back(strFromOs([&](raw_ostream
&os
) {
222 os
<< llvm::format("v%i -> v%i ", n1
.id
, n2
.id
);
223 emitAttrList(os
, attrs
);
227 /// Emit a graph. The specified builder generates the body of the graph.
228 void emitGraph(function_ref
<void()> builder
) {
229 os
<< "digraph G {\n";
231 // Edges between clusters are allowed only in compound mode.
232 os
<< attrStmt("compound", "true") << ";\n";
238 /// Emit a node statement.
239 Node
emitNodeStmt(std::string label
, StringRef shape
= kShapeNode
,
240 StringRef background
= "") {
241 int nodeId
= ++counter
;
243 attrs
["label"] = quoteString(escapeString(std::move(label
)));
244 attrs
["shape"] = shape
.str();
245 if (!background
.empty()) {
246 attrs
["style"] = "filled";
247 attrs
["fillcolor"] = ("\"" + background
+ "\"").str();
249 os
<< llvm::format("v%i ", nodeId
);
250 emitAttrList(os
, attrs
);
255 /// Generate a label for an operation.
256 std::string
getLabel(Operation
*op
) {
257 return strFromOs([&](raw_ostream
&os
) {
258 // Print operation name and type.
260 if (printResultTypes
) {
263 llvm::raw_string_ostream
ss(buf
);
264 interleaveComma(op
->getResultTypes(), ss
);
265 os
<< truncateString(buf
) << ")";
271 for (const NamedAttribute
&attr
: op
->getAttrs()) {
272 os
<< '\n' << attr
.getName().getValue() << ": ";
273 emitMlirAttr(os
, attr
.getValue());
279 /// Generate a label for a block argument.
280 std::string
getLabel(BlockArgument arg
) {
281 return "arg" + std::to_string(arg
.getArgNumber());
284 /// Process a block. Emit a cluster and one node per block argument and
285 /// operation inside the cluster.
286 void processBlock(Block
&block
) {
287 emitClusterStmt([&]() {
288 for (BlockArgument
&blockArg
: block
.getArguments())
289 valueToNode
[blockArg
] = emitNodeStmt(getLabel(blockArg
));
291 // Emit a node for each operation.
292 std::optional
<Node
> prevNode
;
293 for (Operation
&op
: block
) {
294 Node nextNode
= processOperation(&op
);
295 if (printControlFlowEdges
&& prevNode
)
296 emitEdgeStmt(*prevNode
, nextNode
, /*label=*/"",
297 kLineStyleControlFlow
);
303 /// Process an operation. If the operation has regions, emit a cluster.
304 /// Otherwise, emit a node.
305 Node
processOperation(Operation
*op
) {
307 if (op
->getNumRegions() > 0) {
308 // Emit cluster for op with regions.
309 node
= emitClusterStmt(
311 for (Region
®ion
: op
->getRegions())
312 processRegion(region
);
316 node
= emitNodeStmt(getLabel(op
), kShapeNode
,
317 backgroundColors
[op
->getName()].second
);
320 // Insert data flow edges originating from each operand.
321 if (printDataFlowEdges
) {
322 unsigned numOperands
= op
->getNumOperands();
323 for (unsigned i
= 0; i
< numOperands
; i
++)
324 dataFlowEdges
.push_back({op
->getOperand(i
), node
,
325 numOperands
== 1 ? "" : std::to_string(i
)});
328 for (Value result
: op
->getResults())
329 valueToNode
[result
] = node
;
334 /// Process a region.
335 void processRegion(Region
®ion
) {
336 for (Block
&block
: region
.getBlocks())
340 /// Truncate long strings.
341 std::string
truncateString(std::string str
) {
342 if (str
.length() <= maxLabelLen
)
344 return str
.substr(0, maxLabelLen
) + "...";
347 /// Output stream to write DOT file to.
348 raw_indented_ostream os
;
349 /// A list of edges. For simplicity, should be emitted after all nodes were
351 std::vector
<std::string
> edges
;
352 /// Mapping of SSA values to Graphviz nodes/clusters.
353 DenseMap
<Value
, Node
> valueToNode
;
354 /// Output for data flow edges is delayed until the end to handle cycles
355 std::vector
<std::tuple
<Value
, Node
, std::string
>> dataFlowEdges
;
356 /// Counter for generating unique node/subgraph identifiers.
359 DenseMap
<OperationName
, std::pair
<int, std::string
>> backgroundColors
;
364 std::unique_ptr
<Pass
> mlir::createPrintOpGraphPass(raw_ostream
&os
) {
365 return std::make_unique
<PrintOpPass
>(os
);
368 /// Generate a CFG for a region and show it in a window.
369 static void llvmViewGraph(Region
®ion
, const Twine
&name
) {
371 std::string filename
= llvm::createGraphFilename(name
.str(), fd
);
373 llvm::raw_fd_ostream
os(fd
, /*shouldClose=*/true);
375 llvm::errs() << "error opening file '" << filename
<< "' for writing\n";
378 PrintOpPass
pass(os
);
379 pass
.emitRegionCFG(region
);
381 llvm::DisplayGraph(filename
, /*wait=*/false, llvm::GraphProgram::DOT
);
384 void mlir::Region::viewGraph(const Twine
®ionName
) {
385 llvmViewGraph(*this, regionName
);
388 void mlir::Region::viewGraph() { viewGraph("region"); }