1 //===---- X86CondBrFolding.cpp - optimize conditional branches ------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
8 // This file defines a pass that optimizes conditional branches on x86 by taking
9 // advantage of the three-way conditional code generated by compare
11 // Currently, it tries to hoist EQ and NE conditional branches to a dominant
12 // conditional branch condition where the same EQ/NE conditional code is
13 // computed. An example:
26 // Here we could combine the two compares in bb_0 and bb_4 and have the
37 // For the case of %0 == 20 (bb_5), we eliminate two jumps, and the control
38 // height for bb_6 is also reduced. bb_4 is gone after the optimization.
40 // There are plenty of these code patterns, especially from the switch case
41 // lowering where we generate a compare of "pivot-1" for the inner nodes in the
42 // binary search tree.
43 //===----------------------------------------------------------------------===//
46 #include "X86InstrInfo.h"
47 #include "X86Subtarget.h"
48 #include "llvm/ADT/Statistic.h"
49 #include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
50 #include "llvm/CodeGen/MachineFunctionPass.h"
51 #include "llvm/CodeGen/MachineInstrBuilder.h"
52 #include "llvm/CodeGen/MachineRegisterInfo.h"
53 #include "llvm/Support/BranchProbability.h"
57 #define DEBUG_TYPE "x86-condbr-folding"
59 STATISTIC(NumFixedCondBrs
, "Number of x86 condbr folded");
62 class X86CondBrFoldingPass
: public MachineFunctionPass
{
64 X86CondBrFoldingPass() : MachineFunctionPass(ID
) { }
65 StringRef
getPassName() const override
{ return "X86 CondBr Folding"; }
67 bool runOnMachineFunction(MachineFunction
&MF
) override
;
69 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
70 MachineFunctionPass::getAnalysisUsage(AU
);
71 AU
.addRequired
<MachineBranchProbabilityInfo
>();
79 char X86CondBrFoldingPass::ID
= 0;
80 INITIALIZE_PASS(X86CondBrFoldingPass
, "X86CondBrFolding", "X86CondBrFolding", false, false)
82 FunctionPass
*llvm::createX86CondBrFolding() {
83 return new X86CondBrFoldingPass();
87 // A class that stores the auxiliary information for each MBB.
88 struct TargetMBBInfo
{
89 MachineBasicBlock
*TBB
;
90 MachineBasicBlock
*FBB
;
91 MachineInstr
*BrInstr
;
92 MachineInstr
*CmpInstr
;
93 X86::CondCode BranchCode
;
100 // A class that optimizes the conditional branch by hoisting and merging CondCode.
101 class X86CondBrFolding
{
103 X86CondBrFolding(const X86InstrInfo
*TII
,
104 const MachineBranchProbabilityInfo
*MBPI
,
106 : TII(TII
), MBPI(MBPI
), MF(MF
) {}
110 const X86InstrInfo
*TII
;
111 const MachineBranchProbabilityInfo
*MBPI
;
113 std::vector
<std::unique_ptr
<TargetMBBInfo
>> MBBInfos
;
114 SmallVector
<MachineBasicBlock
*, 4> RemoveList
;
116 void optimizeCondBr(MachineBasicBlock
&MBB
,
117 SmallVectorImpl
<MachineBasicBlock
*> &BranchPath
);
118 void fixBranchProb(MachineBasicBlock
*NextMBB
, MachineBasicBlock
*RootMBB
,
119 SmallVectorImpl
<MachineBasicBlock
*> &BranchPath
);
120 void replaceBrDest(MachineBasicBlock
*MBB
, MachineBasicBlock
*OrigDest
,
121 MachineBasicBlock
*NewDest
);
122 void fixupModifiedCond(MachineBasicBlock
*MBB
);
123 std::unique_ptr
<TargetMBBInfo
> analyzeMBB(MachineBasicBlock
&MBB
);
124 static bool analyzeCompare(const MachineInstr
&MI
, unsigned &SrcReg
,
126 bool findPath(MachineBasicBlock
*MBB
,
127 SmallVectorImpl
<MachineBasicBlock
*> &BranchPath
);
128 TargetMBBInfo
*getMBBInfo(MachineBasicBlock
*MBB
) const {
129 return MBBInfos
[MBB
->getNumber()].get();
134 // Find a valid path along which we can reuse the CondCode.
135 // The resulting path (if this returns true) is stored in BranchPath.
137 // false: if no valid path is found.
138 // true: a valid path is found and the targetBB can be reached.
139 bool X86CondBrFolding::findPath(
140 MachineBasicBlock
*MBB
, SmallVectorImpl
<MachineBasicBlock
*> &BranchPath
) {
141 TargetMBBInfo
*MBBInfo
= getMBBInfo(MBB
);
142 assert(MBBInfo
&& "Expecting a candidate MBB");
143 int CmpValue
= MBBInfo
->CmpValue
;
145 MachineBasicBlock
*PredMBB
= *MBB
->pred_begin();
146 MachineBasicBlock
*SaveMBB
= MBB
;
148 TargetMBBInfo
*PredMBBInfo
= getMBBInfo(PredMBB
);
149 if (!PredMBBInfo
|| PredMBBInfo
->SrcReg
!= MBBInfo
->SrcReg
)
152 assert(SaveMBB
== PredMBBInfo
->TBB
|| SaveMBB
== PredMBBInfo
->FBB
);
153 bool IsFalseBranch
= (SaveMBB
== PredMBBInfo
->FBB
);
155 X86::CondCode CC
= PredMBBInfo
->BranchCode
;
156 assert(CC
== X86::COND_L
|| CC
== X86::COND_G
|| CC
== X86::COND_E
);
157 int PredCmpValue
= PredMBBInfo
->CmpValue
;
158 bool ValueCmpTrue
= ((CmpValue
< PredCmpValue
&& CC
== X86::COND_L
) ||
159 (CmpValue
> PredCmpValue
&& CC
== X86::COND_G
) ||
160 (CmpValue
== PredCmpValue
&& CC
== X86::COND_E
));
161 // Check if both the result of value compare and the branch target match.
162 if (!(ValueCmpTrue
^ IsFalseBranch
)) {
163 LLVM_DEBUG(dbgs() << "Dead BB detected!\n");
167 BranchPath
.push_back(PredMBB
);
168 // These are the conditions on which we could combine the compares.
169 if ((CmpValue
== PredCmpValue
) ||
170 (CmpValue
== PredCmpValue
- 1 && CC
== X86::COND_L
) ||
171 (CmpValue
== PredCmpValue
+ 1 && CC
== X86::COND_G
))
174 // If PredMBB has more than one pred, or is not a pure cmp and br, we bail out.
175 if (PredMBB
->pred_size() != 1 || !PredMBBInfo
->CmpBrOnly
)
179 PredMBB
= *PredMBB
->pred_begin();
184 // Fix up any PHI node in the successor of MBB.
185 static void fixPHIsInSucc(MachineBasicBlock
*MBB
, MachineBasicBlock
*OldMBB
,
186 MachineBasicBlock
*NewMBB
) {
187 if (NewMBB
== OldMBB
)
189 for (auto MI
= MBB
->instr_begin(), ME
= MBB
->instr_end();
190 MI
!= ME
&& MI
->isPHI(); ++MI
)
191 for (unsigned i
= 2, e
= MI
->getNumOperands() + 1; i
!= e
; i
+= 2) {
192 MachineOperand
&MO
= MI
->getOperand(i
);
193 if (MO
.getMBB() == OldMBB
)
198 // Utility function to set branch probability for edge MBB->SuccMBB.
199 static inline bool setBranchProb(MachineBasicBlock
*MBB
,
200 MachineBasicBlock
*SuccMBB
,
201 BranchProbability Prob
) {
202 auto MBBI
= std::find(MBB
->succ_begin(), MBB
->succ_end(), SuccMBB
);
203 if (MBBI
== MBB
->succ_end())
205 MBB
->setSuccProbability(MBBI
, Prob
);
209 // Utility function to find the unconditional br instruction in MBB.
210 static inline MachineBasicBlock::iterator
211 findUncondBrI(MachineBasicBlock
*MBB
) {
212 return std::find_if(MBB
->begin(), MBB
->end(), [](MachineInstr
&MI
) -> bool {
213 return MI
.getOpcode() == X86::JMP_1
;
217 // Replace MBB's original successor, OrigDest, with NewDest.
218 // Also update the MBBInfo for MBB.
219 void X86CondBrFolding::replaceBrDest(MachineBasicBlock
*MBB
,
220 MachineBasicBlock
*OrigDest
,
221 MachineBasicBlock
*NewDest
) {
222 TargetMBBInfo
*MBBInfo
= getMBBInfo(MBB
);
224 if (MBBInfo
->TBB
== OrigDest
) {
225 BrMI
= MBBInfo
->BrInstr
;
226 MachineInstrBuilder MIB
=
227 BuildMI(*MBB
, BrMI
, MBB
->findDebugLoc(BrMI
), TII
->get(X86::JCC_1
))
228 .addMBB(NewDest
).addImm(MBBInfo
->BranchCode
);
229 MBBInfo
->TBB
= NewDest
;
230 MBBInfo
->BrInstr
= MIB
.getInstr();
231 } else { // Should be the unconditional jump stmt.
232 MachineBasicBlock::iterator UncondBrI
= findUncondBrI(MBB
);
233 BuildMI(*MBB
, UncondBrI
, MBB
->findDebugLoc(UncondBrI
), TII
->get(X86::JMP_1
))
235 MBBInfo
->FBB
= NewDest
;
238 fixPHIsInSucc(NewDest
, OrigDest
, MBB
);
239 BrMI
->eraseFromParent();
240 MBB
->addSuccessor(NewDest
);
241 setBranchProb(MBB
, NewDest
, MBPI
->getEdgeProbability(MBB
, OrigDest
));
242 MBB
->removeSuccessor(OrigDest
);
245 // Change the CondCode and BrInstr according to MBBInfo.
246 void X86CondBrFolding::fixupModifiedCond(MachineBasicBlock
*MBB
) {
247 TargetMBBInfo
*MBBInfo
= getMBBInfo(MBB
);
248 if (!MBBInfo
->Modified
)
251 MachineInstr
*BrMI
= MBBInfo
->BrInstr
;
252 X86::CondCode CC
= MBBInfo
->BranchCode
;
253 MachineInstrBuilder MIB
= BuildMI(*MBB
, BrMI
, MBB
->findDebugLoc(BrMI
),
254 TII
->get(X86::JCC_1
))
255 .addMBB(MBBInfo
->TBB
).addImm(CC
);
256 BrMI
->eraseFromParent();
257 MBBInfo
->BrInstr
= MIB
.getInstr();
259 MachineBasicBlock::iterator UncondBrI
= findUncondBrI(MBB
);
260 BuildMI(*MBB
, UncondBrI
, MBB
->findDebugLoc(UncondBrI
), TII
->get(X86::JMP_1
))
261 .addMBB(MBBInfo
->FBB
);
262 MBB
->erase(UncondBrI
);
263 MBBInfo
->Modified
= false;
267 // Apply the transformation:
268 // RootMBB -1-> ... PredMBB -3-> MBB -5-> TargetMBB
269 // \-2-> \-4-> \-6-> FalseMBB
271 // RootMBB -1-> ... PredMBB -7-> FalseMBB
272 // TargetMBB <-8-/ \-2-> \-4->
274 // Note that PredMBB and RootMBB could be the same.
275 // And in the case of dead TargetMBB, we will not have TargetMBB and edge 8.
277 // There is some special handling where the RootMBB is COND_E, in which case
278 // we directly short-cycle the brinstr.
280 void X86CondBrFolding::optimizeCondBr(
281 MachineBasicBlock
&MBB
, SmallVectorImpl
<MachineBasicBlock
*> &BranchPath
) {
284 TargetMBBInfo
*MBBInfo
= getMBBInfo(&MBB
);
285 assert(MBBInfo
&& "Expecting a candidate MBB");
286 MachineBasicBlock
*TargetMBB
= MBBInfo
->TBB
;
287 BranchProbability TargetProb
= MBPI
->getEdgeProbability(&MBB
, MBBInfo
->TBB
);
289 // Forward the jump from MBB's predecessor to MBB's false target.
290 MachineBasicBlock
*PredMBB
= BranchPath
.front();
291 TargetMBBInfo
*PredMBBInfo
= getMBBInfo(PredMBB
);
292 assert(PredMBBInfo
&& "Expecting a candidate MBB");
293 if (PredMBBInfo
->Modified
)
294 fixupModifiedCond(PredMBB
);
295 CC
= PredMBBInfo
->BranchCode
;
296 // Don't do this if depth of BranchPath is 1 and PredMBB is of COND_E.
297 // We will short-cycle directly for this case.
298 if (!(CC
== X86::COND_E
&& BranchPath
.size() == 1))
299 replaceBrDest(PredMBB
, &MBB
, MBBInfo
->FBB
);
301 MachineBasicBlock
*RootMBB
= BranchPath
.back();
302 TargetMBBInfo
*RootMBBInfo
= getMBBInfo(RootMBB
);
303 assert(RootMBBInfo
&& "Expecting a candidate MBB");
304 if (RootMBBInfo
->Modified
)
305 fixupModifiedCond(RootMBB
);
306 CC
= RootMBBInfo
->BranchCode
;
308 if (CC
!= X86::COND_E
) {
309 MachineBasicBlock::iterator UncondBrI
= findUncondBrI(RootMBB
);
310 // RootMBB: Cond jump to the original not-taken MBB.
320 llvm_unreachable("unexpected condtional code.");
322 BuildMI(*RootMBB
, UncondBrI
, RootMBB
->findDebugLoc(UncondBrI
),
323 TII
->get(X86::JCC_1
))
324 .addMBB(RootMBBInfo
->FBB
).addImm(NewCC
);
326 // RootMBB: Jump to TargetMBB
327 BuildMI(*RootMBB
, UncondBrI
, RootMBB
->findDebugLoc(UncondBrI
),
328 TII
->get(X86::JMP_1
))
330 RootMBB
->addSuccessor(TargetMBB
);
331 fixPHIsInSucc(TargetMBB
, &MBB
, RootMBB
);
332 RootMBB
->erase(UncondBrI
);
334 replaceBrDest(RootMBB
, RootMBBInfo
->TBB
, TargetMBB
);
337 // Fix RootMBB's CmpValue to MBB's CmpValue to TargetMBB. Don't set Imm
338 // directly. Move MBB's stmt to here as the opcode might be different.
339 if (RootMBBInfo
->CmpValue
!= MBBInfo
->CmpValue
) {
340 MachineInstr
*NewCmp
= MBBInfo
->CmpInstr
;
341 NewCmp
->removeFromParent();
342 RootMBB
->insert(RootMBBInfo
->CmpInstr
, NewCmp
);
343 RootMBBInfo
->CmpInstr
->eraseFromParent();
346 // Fix branch Probabilities.
347 auto fixBranchProb
= [&](MachineBasicBlock
*NextMBB
) {
348 BranchProbability Prob
;
349 for (auto &I
: BranchPath
) {
350 MachineBasicBlock
*ThisMBB
= I
;
351 if (!ThisMBB
->hasSuccessorProbabilities() ||
352 !ThisMBB
->isSuccessor(NextMBB
))
354 Prob
= MBPI
->getEdgeProbability(ThisMBB
, NextMBB
);
355 if (Prob
.isUnknown())
357 TargetProb
= Prob
* TargetProb
;
358 Prob
= Prob
- TargetProb
;
359 setBranchProb(ThisMBB
, NextMBB
, Prob
);
360 if (ThisMBB
== RootMBB
) {
361 setBranchProb(ThisMBB
, TargetMBB
, TargetProb
);
363 ThisMBB
->normalizeSuccProbs();
364 if (ThisMBB
== RootMBB
)
370 if (CC
!= X86::COND_E
&& !TargetProb
.isUnknown())
371 fixBranchProb(MBBInfo
->FBB
);
373 if (CC
!= X86::COND_E
)
374 RemoveList
.push_back(&MBB
);
376 // Invalidate MBBInfo just in case.
377 MBBInfos
[MBB
.getNumber()] = nullptr;
378 MBBInfos
[RootMBB
->getNumber()] = nullptr;
380 LLVM_DEBUG(dbgs() << "After optimization:\nRootMBB is: " << *RootMBB
<< "\n");
381 if (BranchPath
.size() > 1)
382 LLVM_DEBUG(dbgs() << "PredMBB is: " << *(BranchPath
[0]) << "\n");
385 // Driver function for optimization: find the valid candidate and apply
386 // the transformation.
387 bool X86CondBrFolding::optimize() {
388 bool Changed
= false;
389 LLVM_DEBUG(dbgs() << "***** X86CondBr Folding on Function: " << MF
.getName()
391 // Setup data structures.
392 MBBInfos
.resize(MF
.getNumBlockIDs());
394 MBBInfos
[MBB
.getNumber()] = analyzeMBB(MBB
);
396 for (auto &MBB
: MF
) {
397 TargetMBBInfo
*MBBInfo
= getMBBInfo(&MBB
);
398 if (!MBBInfo
|| !MBBInfo
->CmpBrOnly
)
400 if (MBB
.pred_size() != 1)
402 LLVM_DEBUG(dbgs() << "Work on MBB." << MBB
.getNumber()
403 << " CmpValue: " << MBBInfo
->CmpValue
<< "\n");
404 SmallVector
<MachineBasicBlock
*, 4> BranchPath
;
405 if (!findPath(&MBB
, BranchPath
))
409 LLVM_DEBUG(dbgs() << "Found one path (len=" << BranchPath
.size() << "):\n");
411 LLVM_DEBUG(dbgs() << "Target MBB is: " << MBB
<< "\n");
412 for (auto I
= BranchPath
.rbegin(); I
!= BranchPath
.rend(); ++I
, ++Index
) {
413 MachineBasicBlock
*PMBB
= *I
;
414 TargetMBBInfo
*PMBBInfo
= getMBBInfo(PMBB
);
415 LLVM_DEBUG(dbgs() << "Path MBB (" << Index
<< " of " << BranchPath
.size()
416 << ") is " << *PMBB
);
417 LLVM_DEBUG(dbgs() << "CC=" << PMBBInfo
->BranchCode
418 << " Val=" << PMBBInfo
->CmpValue
419 << " CmpBrOnly=" << PMBBInfo
->CmpBrOnly
<< "\n\n");
422 optimizeCondBr(MBB
, BranchPath
);
425 NumFixedCondBrs
+= RemoveList
.size();
426 for (auto MBBI
: RemoveList
) {
427 while (!MBBI
->succ_empty())
428 MBBI
->removeSuccessor(MBBI
->succ_end() - 1);
430 MBBI
->eraseFromParent();
436 // Analyze instructions that generate CondCode and extract information.
437 bool X86CondBrFolding::analyzeCompare(const MachineInstr
&MI
, unsigned &SrcReg
,
439 unsigned SrcRegIndex
= 0;
440 unsigned ValueIndex
= 0;
441 switch (MI
.getOpcode()) {
442 // TODO: handle test instructions.
466 SrcReg
= MI
.getOperand(SrcRegIndex
).getReg();
467 if (!MI
.getOperand(ValueIndex
).isImm())
469 CmpValue
= MI
.getOperand(ValueIndex
).getImm();
473 // Analyze a candidate MBB and extract all the information needed.
474 // The valid candidate will have two successors.
475 // It also should have a sequence of
479 // Return TargetMBBInfo if MBB is a valid candidate and nullptr otherwise.
480 std::unique_ptr
<TargetMBBInfo
>
481 X86CondBrFolding::analyzeMBB(MachineBasicBlock
&MBB
) {
482 MachineBasicBlock
*TBB
;
483 MachineBasicBlock
*FBB
;
484 MachineInstr
*BrInstr
;
485 MachineInstr
*CmpInstr
;
492 if (MBB
.succ_size() != 2)
498 MachineBasicBlock::iterator I
= MBB
.end();
499 while (I
!= MBB
.begin()) {
501 if (I
->isDebugValue())
503 if (I
->getOpcode() == X86::JMP_1
) {
506 FBB
= I
->getOperand(0).getMBB();
512 CC
= X86::getCondFromBranch(*I
);
524 TBB
= I
->getOperand(0).getMBB();
528 if (analyzeCompare(*I
, SrcReg
, CmpValue
)) {
538 if (!TBB
|| !FBB
|| !CmpInstr
)
541 // Simplify CondCode. Note this is only to simplify the findPath logic
542 // and will not change the instruction here.
550 if (CmpValue
== INT_MAX
)
557 if (CmpValue
== INT_MIN
)
567 return llvm::make_unique
<TargetMBBInfo
>(TargetMBBInfo
{
568 TBB
, FBB
, BrInstr
, CmpInstr
, CC
, SrcReg
, CmpValue
, Modified
, CmpBrOnly
});
571 bool X86CondBrFoldingPass::runOnMachineFunction(MachineFunction
&MF
) {
572 const X86Subtarget
&ST
= MF
.getSubtarget
<X86Subtarget
>();
573 if (!ST
.threewayBranchProfitable())
575 const X86InstrInfo
*TII
= ST
.getInstrInfo();
576 const MachineBranchProbabilityInfo
*MBPI
=
577 &getAnalysis
<MachineBranchProbabilityInfo
>();
579 X86CondBrFolding
CondBr(TII
, MBPI
, MF
);
580 return CondBr
.optimize();