//===- bolt/Passes/JTFootprintReduction.cpp -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the JTFootprintReduction pass, which rewrites
// jump-table dispatches so that their jump tables can use 4-byte entries.
//
//===----------------------------------------------------------------------===//
#include "bolt/Passes/JTFootprintReduction.h"
#include "bolt/Passes/BinaryFunctionCallGraph.h"
#include "bolt/Passes/DataflowInfoManager.h"
#include "llvm/Support/CommandLine.h"

#define DEBUG_TYPE "JT"

using namespace llvm;
using namespace bolt;
25 extern cl::OptionCategory BoltOptCategory
;
27 extern cl::opt
<unsigned> Verbosity
;
29 extern cl::opt
<JumpTableSupportLevel
> JumpTables
;
31 static cl::opt
<bool> JTFootprintOnlyPIC(
32 "jt-footprint-optimize-for-icache",
33 cl::desc("with jt-footprint-reduction, only process PIC jumptables and turn"
34 " off other transformations that increase code size"),
35 cl::init(false), cl::ZeroOrMore
, cl::cat(BoltOptCategory
));
42 void JTFootprintReduction::checkOpportunities(BinaryFunction
&Function
,
43 DataflowInfoManager
&Info
) {
44 BinaryContext
&BC
= Function
.getBinaryContext();
45 std::map
<JumpTable
*, uint64_t> AllJTs
;
47 for (BinaryBasicBlock
&BB
: Function
) {
48 for (MCInst
&Inst
: BB
) {
49 JumpTable
*JumpTable
= Function
.getJumpTable(Inst
);
53 AllJTs
[JumpTable
] += BB
.getKnownExecutionCount();
56 if (BlacklistedJTs
.count(JumpTable
)) {
62 // Try a standard indirect jump matcher
63 std::unique_ptr
<MCPlusBuilder::MCInstMatcher
> IndJmpMatcher
=
64 BC
.MIB
->matchIndJmp(BC
.MIB
->matchAnyOperand(),
65 BC
.MIB
->matchImm(Scale
), BC
.MIB
->matchReg(),
66 BC
.MIB
->matchAnyOperand());
67 if (!opts::JTFootprintOnlyPIC
&&
68 IndJmpMatcher
->match(*BC
.MRI
, *BC
.MIB
,
69 MutableArrayRef
<MCInst
>(&*BB
.begin(), &Inst
+ 1),
72 if (Info
.getLivenessAnalysis().scavengeRegAfter(&Inst
))
74 BlacklistedJTs
.insert(JumpTable
);
80 // Try a PIC matcher. The pattern we are looking for is a PIC JT ind jmp:
83 // leaq DATAat0x402450(%rip), %r11
84 // movslq (%r11,%rdx,4), %rcx
86 // jmpq *%rcx # JUMPTABLE @0x402450
90 std::unique_ptr
<MCPlusBuilder::MCInstMatcher
> PICIndJmpMatcher
=
91 BC
.MIB
->matchIndJmp(BC
.MIB
->matchAdd(
92 BC
.MIB
->matchReg(BaseReg1
),
93 BC
.MIB
->matchLoad(BC
.MIB
->matchReg(BaseReg2
),
94 BC
.MIB
->matchImm(Scale
), BC
.MIB
->matchReg(),
95 BC
.MIB
->matchImm(Offset
))));
96 std::unique_ptr
<MCPlusBuilder::MCInstMatcher
> PICBaseAddrMatcher
=
98 BC
.MIB
->matchAdd(BC
.MIB
->matchLoadAddr(BC
.MIB
->matchSymbol()),
99 BC
.MIB
->matchAnyOperand()));
100 if (!PICIndJmpMatcher
->match(
102 MutableArrayRef
<MCInst
>(&*BB
.begin(), &Inst
+ 1), -1) ||
103 Scale
!= 4 || BaseReg1
!= BaseReg2
|| Offset
!= 0 ||
104 !PICBaseAddrMatcher
->match(
106 MutableArrayRef
<MCInst
>(&*BB
.begin(), &Inst
+ 1), -1)) {
107 BlacklistedJTs
.insert(JumpTable
);
116 for (const auto &JTFreq
: AllJTs
) {
117 JumpTable
*JT
= JTFreq
.first
;
118 uint64_t CurScore
= JTFreq
.second
;
119 TotalJTScore
+= CurScore
;
120 if (!BlacklistedJTs
.count(JT
)) {
121 OptimizedScore
+= CurScore
;
122 if (JT
->EntrySize
== 8)
123 BytesSaved
+= JT
->getSize() >> 1;
126 TotalJTs
+= AllJTs
.size();
127 TotalJTsDenied
+= BlacklistedJTs
.size();
130 bool JTFootprintReduction::tryOptimizeNonPIC(
131 BinaryContext
&BC
, BinaryBasicBlock
&BB
, BinaryBasicBlock::iterator Inst
,
132 uint64_t JTAddr
, JumpTable
*JumpTable
, DataflowInfoManager
&Info
) {
133 if (opts::JTFootprintOnlyPIC
)
140 std::unique_ptr
<MCPlusBuilder::MCInstMatcher
> IndJmpMatcher
=
141 BC
.MIB
->matchIndJmp(BC
.MIB
->matchAnyOperand(Base
),
142 BC
.MIB
->matchImm(Scale
), BC
.MIB
->matchReg(Index
),
143 BC
.MIB
->matchAnyOperand(Offset
));
144 if (!IndJmpMatcher
->match(*BC
.MRI
, *BC
.MIB
,
145 MutableArrayRef
<MCInst
>(&*BB
.begin(), &*Inst
+ 1),
149 assert(Scale
== 8 && "Wrong scale");
152 IndJmpMatcher
->annotate(*BC
.MIB
, "DeleteMe");
154 LivenessAnalysis
&LA
= Info
.getLivenessAnalysis();
155 MCPhysReg Reg
= LA
.scavengeRegAfter(&*Inst
);
156 assert(Reg
!= 0 && "Register scavenger failed!");
157 MCOperand RegOp
= MCOperand::createReg(Reg
);
158 SmallVector
<MCInst
, 4> NewFrag
;
160 BC
.MIB
->createIJmp32Frag(NewFrag
, Base
, MCOperand::createImm(Scale
),
161 MCOperand::createReg(Index
), Offset
, RegOp
);
162 BC
.MIB
->setJumpTable(NewFrag
.back(), JTAddr
, Index
);
164 JumpTable
->OutputEntrySize
= 4;
166 BB
.replaceInstruction(Inst
, NewFrag
.begin(), NewFrag
.end());
170 bool JTFootprintReduction::tryOptimizePIC(BinaryContext
&BC
,
171 BinaryBasicBlock
&BB
,
172 BinaryBasicBlock::iterator Inst
,
173 uint64_t JTAddr
, JumpTable
*JumpTable
,
174 DataflowInfoManager
&Info
) {
179 MCOperand JumpTableRef
;
180 std::unique_ptr
<MCPlusBuilder::MCInstMatcher
> PICIndJmpMatcher
=
181 BC
.MIB
->matchIndJmp(BC
.MIB
->matchAdd(
182 BC
.MIB
->matchLoadAddr(BC
.MIB
->matchAnyOperand(JumpTableRef
)),
183 BC
.MIB
->matchLoad(BC
.MIB
->matchReg(BaseReg
), BC
.MIB
->matchImm(Scale
),
184 BC
.MIB
->matchReg(Index
),
185 BC
.MIB
->matchAnyOperand())));
186 if (!PICIndJmpMatcher
->match(
187 *BC
.MRI
, *BC
.MIB
, MutableArrayRef
<MCInst
>(&*BB
.begin(), &*Inst
+ 1),
191 assert(Scale
== 4 && "Wrong scale");
193 PICIndJmpMatcher
->annotate(*BC
.MIB
, "DeleteMe");
195 MCOperand RegOp
= MCOperand::createReg(BaseReg
);
196 SmallVector
<MCInst
, 4> NewFrag
;
198 BC
.MIB
->createIJmp32Frag(NewFrag
, MCOperand::createReg(0),
199 MCOperand::createImm(Scale
),
200 MCOperand::createReg(Index
), JumpTableRef
, RegOp
);
201 BC
.MIB
->setJumpTable(NewFrag
.back(), JTAddr
, Index
);
203 JumpTable
->OutputEntrySize
= 4;
205 JumpTable
->Type
= JumpTable::JTT_NORMAL
;
207 BB
.replaceInstruction(Inst
, NewFrag
.begin(), NewFrag
.end());
211 void JTFootprintReduction::optimizeFunction(BinaryFunction
&Function
,
212 DataflowInfoManager
&Info
) {
213 BinaryContext
&BC
= Function
.getBinaryContext();
214 for (BinaryBasicBlock
&BB
: Function
) {
215 if (!BB
.getNumNonPseudos())
218 auto IndJmpRI
= BB
.getLastNonPseudo();
219 auto IndJmp
= std::prev(IndJmpRI
.base());
220 const uint64_t JTAddr
= BC
.MIB
->getJumpTable(*IndJmp
);
225 JumpTable
*JumpTable
= Function
.getJumpTable(*IndJmp
);
226 if (BlacklistedJTs
.count(JumpTable
))
229 if (tryOptimizeNonPIC(BC
, BB
, IndJmp
, JTAddr
, JumpTable
, Info
) ||
230 tryOptimizePIC(BC
, BB
, IndJmp
, JTAddr
, JumpTable
, Info
)) {
231 Modified
.insert(&Function
);
235 llvm_unreachable("Should either optimize PIC or NonPIC successfuly");
238 if (!Modified
.count(&Function
))
241 for (BinaryBasicBlock
&BB
: Function
)
242 for (auto I
= BB
.begin(); I
!= BB
.end();)
243 if (BC
.MIB
->hasAnnotation(*I
, "DeleteMe"))
244 I
= BB
.eraseInstruction(I
);
249 void JTFootprintReduction::runOnFunctions(BinaryContext
&BC
) {
250 if (opts::JumpTables
== JTS_BASIC
&& BC
.HasRelocations
)
253 std::unique_ptr
<RegAnalysis
> RA
;
254 std::unique_ptr
<BinaryFunctionCallGraph
> CG
;
255 if (!opts::JTFootprintOnlyPIC
) {
256 CG
.reset(new BinaryFunctionCallGraph(buildCallGraph(BC
)));
257 RA
.reset(new RegAnalysis(BC
, &BC
.getBinaryFunctions(), &*CG
));
259 for (auto &BFIt
: BC
.getBinaryFunctions()) {
260 BinaryFunction
&Function
= BFIt
.second
;
262 if (!Function
.isSimple() || Function
.isIgnored())
265 if (Function
.getKnownExecutionCount() == 0)
268 DataflowInfoManager
Info(Function
, RA
.get(), nullptr);
269 BlacklistedJTs
.clear();
270 checkOpportunities(Function
, Info
);
271 optimizeFunction(Function
, Info
);
274 if (TotalJTs
== TotalJTsDenied
) {
275 outs() << "BOLT-INFO: JT Footprint reduction: no changes were made.\n";
279 outs() << "BOLT-INFO: JT Footprint reduction stats (simple funcs only):\n";
281 outs() << format("\t %.2lf%%", (OptimizedScore
* 100.0 / TotalJTScore
))
282 << " of dynamic JT entries were reduced.\n";
283 outs() << "\t " << TotalJTs
- TotalJTsDenied
<< " of " << TotalJTs
284 << " jump tables affected.\n";
285 outs() << "\t " << IndJmps
- IndJmpsDenied
<< " of " << IndJmps
286 << " indirect jumps to JTs affected.\n";
287 outs() << "\t " << NumJTsBadMatch
288 << " JTs discarded due to unsupported jump pattern.\n";
289 outs() << "\t " << NumJTsNoReg
290 << " JTs discarded due to register unavailability.\n";
291 outs() << "\t " << BytesSaved
<< " bytes saved.\n";