//===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines '-dot-callgraph', which emits the call graph of a module
// to a DOT file named <module identifier>.callgraph.dot, or
// <prefix>.callgraph.dot when -callgraph-dot-filename-prefix is given.
//
// There is also a pass available to directly call dotty ('-view-callgraph').
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/CallPrinter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/HeatUtils.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/DOTGraphTraits.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;
// Forward declaration of the generic graph traits template; the
// specializations for CallGraphDOTInfo below must live in namespace llvm.
namespace llvm {
template <typename GraphType> struct GraphTraits;
} // namespace llvm
35 // This option shows static (relative) call counts.
37 // Need to show real counts when profile data is available
38 static cl::opt
<bool> ShowHeatColors("callgraph-heat-colors", cl::init(false),
40 cl::desc("Show heat colors in call-graph"));
43 ShowEdgeWeight("callgraph-show-weights", cl::init(false), cl::Hidden
,
44 cl::desc("Show edges labeled with weights"));
47 CallMultiGraph("callgraph-multigraph", cl::init(false), cl::Hidden
,
48 cl::desc("Show call-multigraph (do not remove parallel edges)"));
50 static cl::opt
<std::string
> CallGraphDotFilenamePrefix(
51 "callgraph-dot-filename-prefix", cl::Hidden
,
52 cl::desc("The prefix used for the CallGraph dot file names."));
56 class CallGraphDOTInfo
{
60 DenseMap
<const Function
*, uint64_t> Freq
;
64 std::function
<BlockFrequencyInfo
*(Function
&)> LookupBFI
;
66 CallGraphDOTInfo(Module
*M
, CallGraph
*CG
,
67 function_ref
<BlockFrequencyInfo
*(Function
&)> LookupBFI
)
68 : M(M
), CG(CG
), LookupBFI(LookupBFI
) {
71 for (Function
&F
: M
->getFunctionList()) {
72 uint64_t localSumFreq
= 0;
73 SmallSet
<Function
*, 16> Callers
;
74 for (User
*U
: F
.users())
76 Callers
.insert(cast
<Instruction
>(U
)->getFunction());
77 for (Function
*Caller
: Callers
)
78 localSumFreq
+= getNumOfCalls(*Caller
, F
);
79 if (localSumFreq
>= MaxFreq
)
80 MaxFreq
= localSumFreq
;
81 Freq
[&F
] = localSumFreq
;
84 removeParallelEdges();
87 Module
*getModule() const { return M
; }
89 CallGraph
*getCallGraph() const { return CG
; }
91 uint64_t getFreq(const Function
*F
) { return Freq
[F
]; }
93 uint64_t getMaxFreq() { return MaxFreq
; }
96 void removeParallelEdges() {
97 for (auto &I
: (*CG
)) {
98 CallGraphNode
*Node
= I
.second
.get();
100 bool FoundParallelEdge
= true;
101 while (FoundParallelEdge
) {
102 SmallSet
<Function
*, 16> Visited
;
103 FoundParallelEdge
= false;
104 for (auto CI
= Node
->begin(), CE
= Node
->end(); CI
!= CE
; CI
++) {
105 if (!(Visited
.insert(CI
->second
->getFunction())).second
) {
106 FoundParallelEdge
= true;
107 Node
->removeCallEdge(CI
);
117 struct GraphTraits
<CallGraphDOTInfo
*>
118 : public GraphTraits
<const CallGraphNode
*> {
119 static NodeRef
getEntryNode(CallGraphDOTInfo
*CGInfo
) {
120 // Start at the external node!
121 return CGInfo
->getCallGraph()->getExternalCallingNode();
124 typedef std::pair
<const Function
*const, std::unique_ptr
<CallGraphNode
>>
126 static const CallGraphNode
*CGGetValuePtr(const PairTy
&P
) {
127 return P
.second
.get();
130 // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
131 typedef mapped_iterator
<CallGraph::const_iterator
, decltype(&CGGetValuePtr
)>
134 static nodes_iterator
nodes_begin(CallGraphDOTInfo
*CGInfo
) {
135 return nodes_iterator(CGInfo
->getCallGraph()->begin(), &CGGetValuePtr
);
137 static nodes_iterator
nodes_end(CallGraphDOTInfo
*CGInfo
) {
138 return nodes_iterator(CGInfo
->getCallGraph()->end(), &CGGetValuePtr
);
143 struct DOTGraphTraits
<CallGraphDOTInfo
*> : public DefaultDOTGraphTraits
{
145 DOTGraphTraits(bool isSimple
= false) : DefaultDOTGraphTraits(isSimple
) {}
147 static std::string
getGraphName(CallGraphDOTInfo
*CGInfo
) {
148 return "Call graph: " +
149 std::string(CGInfo
->getModule()->getModuleIdentifier());
152 static bool isNodeHidden(const CallGraphNode
*Node
,
153 const CallGraphDOTInfo
*CGInfo
) {
154 if (CallMultiGraph
|| Node
->getFunction())
159 std::string
getNodeLabel(const CallGraphNode
*Node
,
160 CallGraphDOTInfo
*CGInfo
) {
161 if (Node
== CGInfo
->getCallGraph()->getExternalCallingNode())
162 return "external caller";
163 if (Node
== CGInfo
->getCallGraph()->getCallsExternalNode())
164 return "external callee";
166 if (Function
*Func
= Node
->getFunction())
167 return std::string(Func
->getName());
168 return "external node";
170 static const CallGraphNode
*CGGetValuePtr(CallGraphNode::CallRecord P
) {
174 // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
175 typedef mapped_iterator
<CallGraphNode::const_iterator
,
176 decltype(&CGGetValuePtr
)>
179 std::string
getEdgeAttributes(const CallGraphNode
*Node
, nodes_iterator I
,
180 CallGraphDOTInfo
*CGInfo
) {
184 Function
*Caller
= Node
->getFunction();
185 if (Caller
== nullptr || Caller
->isDeclaration())
188 Function
*Callee
= (*I
)->getFunction();
189 if (Callee
== nullptr)
192 uint64_t Counter
= getNumOfCalls(*Caller
, *Callee
);
194 1 + 2 * (double(Counter
) / CGInfo
->getMaxFreq());
195 std::string Attrs
= "label=\"" + std::to_string(Counter
) +
196 "\" penwidth=" + std::to_string(Width
);
200 std::string
getNodeAttributes(const CallGraphNode
*Node
,
201 CallGraphDOTInfo
*CGInfo
) {
202 Function
*F
= Node
->getFunction();
206 if (ShowHeatColors
) {
207 uint64_t freq
= CGInfo
->getFreq(F
);
208 std::string color
= getHeatColor(freq
, CGInfo
->getMaxFreq());
209 std::string edgeColor
= (freq
<= (CGInfo
->getMaxFreq() / 2))
212 attrs
= "color=\"" + edgeColor
+ "ff\", style=filled, fillcolor=\"" +
222 void doCallGraphDOTPrinting(
223 Module
&M
, function_ref
<BlockFrequencyInfo
*(Function
&)> LookupBFI
) {
224 std::string Filename
;
225 if (!CallGraphDotFilenamePrefix
.empty())
226 Filename
= (CallGraphDotFilenamePrefix
+ ".callgraph.dot");
228 Filename
= (std::string(M
.getModuleIdentifier()) + ".callgraph.dot");
229 errs() << "Writing '" << Filename
<< "'...";
232 raw_fd_ostream
File(Filename
, EC
, sys::fs::OF_Text
);
235 CallGraphDOTInfo
CFGInfo(&M
, &CG
, LookupBFI
);
238 WriteGraph(File
, &CFGInfo
);
240 errs() << " error opening file for writing!";
244 void viewCallGraph(Module
&M
,
245 function_ref
<BlockFrequencyInfo
*(Function
&)> LookupBFI
) {
247 CallGraphDOTInfo
CFGInfo(&M
, &CG
, LookupBFI
);
250 DOTGraphTraits
<CallGraphDOTInfo
*>::getGraphName(&CFGInfo
);
251 ViewGraph(&CFGInfo
, "callgraph", true, Title
);
256 PreservedAnalyses
CallGraphDOTPrinterPass::run(Module
&M
,
257 ModuleAnalysisManager
&AM
) {
258 FunctionAnalysisManager
&FAM
=
259 AM
.getResult
<FunctionAnalysisManagerModuleProxy
>(M
).getManager();
261 auto LookupBFI
= [&FAM
](Function
&F
) {
262 return &FAM
.getResult
<BlockFrequencyAnalysis
>(F
);
265 doCallGraphDOTPrinting(M
, LookupBFI
);
267 return PreservedAnalyses::all();
270 PreservedAnalyses
CallGraphViewerPass::run(Module
&M
,
271 ModuleAnalysisManager
&AM
) {
273 FunctionAnalysisManager
&FAM
=
274 AM
.getResult
<FunctionAnalysisManagerModuleProxy
>(M
).getManager();
276 auto LookupBFI
= [&FAM
](Function
&F
) {
277 return &FAM
.getResult
<BlockFrequencyAnalysis
>(F
);
280 viewCallGraph(M
, LookupBFI
);
282 return PreservedAnalyses::all();
288 class CallGraphViewer
: public ModulePass
{
291 CallGraphViewer() : ModulePass(ID
) {}
293 void getAnalysisUsage(AnalysisUsage
&AU
) const override
;
294 bool runOnModule(Module
&M
) override
;
297 void CallGraphViewer::getAnalysisUsage(AnalysisUsage
&AU
) const {
298 ModulePass::getAnalysisUsage(AU
);
299 AU
.addRequired
<BlockFrequencyInfoWrapperPass
>();
300 AU
.setPreservesAll();
303 bool CallGraphViewer::runOnModule(Module
&M
) {
304 auto LookupBFI
= [this](Function
&F
) {
305 return &this->getAnalysis
<BlockFrequencyInfoWrapperPass
>(F
).getBFI();
308 viewCallGraph(M
, LookupBFI
);
315 class CallGraphDOTPrinter
: public ModulePass
{
318 CallGraphDOTPrinter() : ModulePass(ID
) {}
320 void getAnalysisUsage(AnalysisUsage
&AU
) const override
;
321 bool runOnModule(Module
&M
) override
;
324 void CallGraphDOTPrinter::getAnalysisUsage(AnalysisUsage
&AU
) const {
325 ModulePass::getAnalysisUsage(AU
);
326 AU
.addRequired
<BlockFrequencyInfoWrapperPass
>();
327 AU
.setPreservesAll();
330 bool CallGraphDOTPrinter::runOnModule(Module
&M
) {
331 auto LookupBFI
= [this](Function
&F
) {
332 return &this->getAnalysis
<BlockFrequencyInfoWrapperPass
>(F
).getBFI();
335 doCallGraphDOTPrinting(M
, LookupBFI
);
340 } // end anonymous namespace
342 char CallGraphViewer::ID
= 0;
343 INITIALIZE_PASS(CallGraphViewer
, "view-callgraph", "View call graph", false,
346 char CallGraphDOTPrinter::ID
= 0;
347 INITIALIZE_PASS(CallGraphDOTPrinter
, "dot-callgraph",
348 "Print call graph to 'dot' file", false, false)
// Create methods available outside of this file, so that they can be
// referenced from "include/llvm/LinkAllPasses.h". Otherwise the passes would
// be deleted by link-time optimization.
354 ModulePass
*llvm::createCallGraphViewerPass() { return new CallGraphViewer(); }
356 ModulePass
*llvm::createCallGraphDOTPrinterPass() {
357 return new CallGraphDOTPrinter();