1 //===- CallPrinter.cpp - DOT printer for call graph -----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file defines '-dot-callgraph', which emits the call graph of a module
// to '<module-id>.callgraph.dot' (or '<prefix>.callgraph.dot' when
// -callgraph-dot-filename-prefix is given).
12 // There is also a pass available to directly call dotty ('-view-callgraph').
14 //===----------------------------------------------------------------------===//
#include "llvm/Analysis/CallPrinter.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/DOTGraphTraitsPass.h"
#include "llvm/Analysis/HeatUtils.h"
#include "llvm/IR/Instructions.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include <algorithm>

using namespace llvm;
29 // This option shows static (relative) call counts.
31 // Need to show real counts when profile data is available
32 static cl::opt
<bool> ShowHeatColors("callgraph-heat-colors", cl::init(false),
34 cl::desc("Show heat colors in call-graph"));
37 ShowEdgeWeight("callgraph-show-weights", cl::init(false), cl::Hidden
,
38 cl::desc("Show edges labeled with weights"));
41 CallMultiGraph("callgraph-multigraph", cl::init(false), cl::Hidden
,
42 cl::desc("Show call-multigraph (do not remove parallel edges)"));
44 static cl::opt
<std::string
> CallGraphDotFilenamePrefix(
45 "callgraph-dot-filename-prefix", cl::Hidden
,
46 cl::desc("The prefix used for the CallGraph dot file names."));
50 class CallGraphDOTInfo
{
54 DenseMap
<const Function
*, uint64_t> Freq
;
58 std::function
<BlockFrequencyInfo
*(Function
&)> LookupBFI
;
60 CallGraphDOTInfo(Module
*M
, CallGraph
*CG
,
61 function_ref
<BlockFrequencyInfo
*(Function
&)> LookupBFI
)
62 : M(M
), CG(CG
), LookupBFI(LookupBFI
) {
65 for (Function
&F
: M
->getFunctionList()) {
66 uint64_t localSumFreq
= 0;
67 SmallSet
<Function
*, 16> Callers
;
68 for (User
*U
: F
.users())
70 Callers
.insert(cast
<Instruction
>(U
)->getFunction());
71 for (Function
*Caller
: Callers
)
72 localSumFreq
+= getNumOfCalls(*Caller
, F
);
73 if (localSumFreq
>= MaxFreq
)
74 MaxFreq
= localSumFreq
;
75 Freq
[&F
] = localSumFreq
;
78 removeParallelEdges();
81 Module
*getModule() const { return M
; }
83 CallGraph
*getCallGraph() const { return CG
; }
85 uint64_t getFreq(const Function
*F
) { return Freq
[F
]; }
87 uint64_t getMaxFreq() { return MaxFreq
; }
90 void removeParallelEdges() {
91 for (auto &I
: (*CG
)) {
92 CallGraphNode
*Node
= I
.second
.get();
94 bool FoundParallelEdge
= true;
95 while (FoundParallelEdge
) {
96 SmallSet
<Function
*, 16> Visited
;
97 FoundParallelEdge
= false;
98 for (auto CI
= Node
->begin(), CE
= Node
->end(); CI
!= CE
; CI
++) {
99 if (!(Visited
.insert(CI
->second
->getFunction())).second
) {
100 FoundParallelEdge
= true;
101 Node
->removeCallEdge(CI
);
111 struct GraphTraits
<CallGraphDOTInfo
*>
112 : public GraphTraits
<const CallGraphNode
*> {
113 static NodeRef
getEntryNode(CallGraphDOTInfo
*CGInfo
) {
114 // Start at the external node!
115 return CGInfo
->getCallGraph()->getExternalCallingNode();
118 typedef std::pair
<const Function
*const, std::unique_ptr
<CallGraphNode
>>
120 static const CallGraphNode
*CGGetValuePtr(const PairTy
&P
) {
121 return P
.second
.get();
124 // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
125 typedef mapped_iterator
<CallGraph::const_iterator
, decltype(&CGGetValuePtr
)>
128 static nodes_iterator
nodes_begin(CallGraphDOTInfo
*CGInfo
) {
129 return nodes_iterator(CGInfo
->getCallGraph()->begin(), &CGGetValuePtr
);
131 static nodes_iterator
nodes_end(CallGraphDOTInfo
*CGInfo
) {
132 return nodes_iterator(CGInfo
->getCallGraph()->end(), &CGGetValuePtr
);
137 struct DOTGraphTraits
<CallGraphDOTInfo
*> : public DefaultDOTGraphTraits
{
139 DOTGraphTraits(bool isSimple
= false) : DefaultDOTGraphTraits(isSimple
) {}
141 static std::string
getGraphName(CallGraphDOTInfo
*CGInfo
) {
142 return "Call graph: " +
143 std::string(CGInfo
->getModule()->getModuleIdentifier());
146 static bool isNodeHidden(const CallGraphNode
*Node
,
147 const CallGraphDOTInfo
*CGInfo
) {
148 if (CallMultiGraph
|| Node
->getFunction())
153 std::string
getNodeLabel(const CallGraphNode
*Node
,
154 CallGraphDOTInfo
*CGInfo
) {
155 if (Node
== CGInfo
->getCallGraph()->getExternalCallingNode())
156 return "external caller";
157 if (Node
== CGInfo
->getCallGraph()->getCallsExternalNode())
158 return "external callee";
160 if (Function
*Func
= Node
->getFunction())
161 return std::string(Func
->getName());
162 return "external node";
164 static const CallGraphNode
*CGGetValuePtr(CallGraphNode::CallRecord P
) {
168 // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
169 typedef mapped_iterator
<CallGraphNode::const_iterator
,
170 decltype(&CGGetValuePtr
)>
173 std::string
getEdgeAttributes(const CallGraphNode
*Node
, nodes_iterator I
,
174 CallGraphDOTInfo
*CGInfo
) {
178 Function
*Caller
= Node
->getFunction();
179 if (Caller
== nullptr || Caller
->isDeclaration())
182 Function
*Callee
= (*I
)->getFunction();
183 if (Callee
== nullptr)
186 uint64_t Counter
= getNumOfCalls(*Caller
, *Callee
);
188 1 + 2 * (double(Counter
) / CGInfo
->getMaxFreq());
189 std::string Attrs
= "label=\"" + std::to_string(Counter
) +
190 "\" penwidth=" + std::to_string(Width
);
194 std::string
getNodeAttributes(const CallGraphNode
*Node
,
195 CallGraphDOTInfo
*CGInfo
) {
196 Function
*F
= Node
->getFunction();
200 if (ShowHeatColors
) {
201 uint64_t freq
= CGInfo
->getFreq(F
);
202 std::string color
= getHeatColor(freq
, CGInfo
->getMaxFreq());
203 std::string edgeColor
= (freq
<= (CGInfo
->getMaxFreq() / 2))
206 attrs
= "color=\"" + edgeColor
+ "ff\", style=filled, fillcolor=\"" +
213 } // end llvm namespace
217 class CallGraphViewer
: public ModulePass
{
220 CallGraphViewer() : ModulePass(ID
) {}
222 void getAnalysisUsage(AnalysisUsage
&AU
) const override
;
223 bool runOnModule(Module
&M
) override
;
226 void CallGraphViewer::getAnalysisUsage(AnalysisUsage
&AU
) const {
227 ModulePass::getAnalysisUsage(AU
);
228 AU
.addRequired
<BlockFrequencyInfoWrapperPass
>();
229 AU
.setPreservesAll();
232 bool CallGraphViewer::runOnModule(Module
&M
) {
233 auto LookupBFI
= [this](Function
&F
) {
234 return &this->getAnalysis
<BlockFrequencyInfoWrapperPass
>(F
).getBFI();
238 CallGraphDOTInfo
CFGInfo(&M
, &CG
, LookupBFI
);
241 DOTGraphTraits
<CallGraphDOTInfo
*>::getGraphName(&CFGInfo
);
242 ViewGraph(&CFGInfo
, "callgraph", true, Title
);
249 class CallGraphDOTPrinter
: public ModulePass
{
252 CallGraphDOTPrinter() : ModulePass(ID
) {}
254 void getAnalysisUsage(AnalysisUsage
&AU
) const override
;
255 bool runOnModule(Module
&M
) override
;
258 void CallGraphDOTPrinter::getAnalysisUsage(AnalysisUsage
&AU
) const {
259 ModulePass::getAnalysisUsage(AU
);
260 AU
.addRequired
<BlockFrequencyInfoWrapperPass
>();
261 AU
.setPreservesAll();
264 bool CallGraphDOTPrinter::runOnModule(Module
&M
) {
265 auto LookupBFI
= [this](Function
&F
) {
266 return &this->getAnalysis
<BlockFrequencyInfoWrapperPass
>(F
).getBFI();
269 std::string Filename
;
270 if (!CallGraphDotFilenamePrefix
.empty())
271 Filename
= (CallGraphDotFilenamePrefix
+ ".callgraph.dot");
273 Filename
= (std::string(M
.getModuleIdentifier()) + ".callgraph.dot");
274 errs() << "Writing '" << Filename
<< "'...";
277 raw_fd_ostream
File(Filename
, EC
, sys::fs::OF_Text
);
280 CallGraphDOTInfo
CFGInfo(&M
, &CG
, LookupBFI
);
283 WriteGraph(File
, &CFGInfo
);
285 errs() << " error opening file for writing!";
291 } // end anonymous namespace
293 char CallGraphViewer::ID
= 0;
294 INITIALIZE_PASS(CallGraphViewer
, "view-callgraph", "View call graph", false,
297 char CallGraphDOTPrinter::ID
= 0;
298 INITIALIZE_PASS(CallGraphDOTPrinter
, "dot-callgraph",
299 "Print call graph to 'dot' file", false, false)
// Factory functions exported from this file so that
// "include/llvm/LinkAllPasses.h" can reference them; otherwise link-time
// optimization could discard the passes.
305 ModulePass
*llvm::createCallGraphViewerPass() { return new CallGraphViewer(); }
307 ModulePass
*llvm::createCallGraphDOTPrinterPass() {
308 return new CallGraphDOTPrinter();