[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / llvm / lib / Target / AArch64 / AArch64CompressJumpTables.cpp
blob7d14d2d20bad33bc8f1f49ac8e4ae2cf6cdace42
1 //==-- AArch64CompressJumpTables.cpp - Compress jump tables for AArch64 --====//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // This pass looks at the basic blocks each jump-table refers to and works out
8 // whether they can be emitted in a compressed form (with 8 or 16-bit
9 // entries). If so, it changes the opcode and flags them in the associated
10 // AArch64FunctionInfo.
12 //===----------------------------------------------------------------------===//
14 #include "AArch64.h"
15 #include "AArch64MachineFunctionInfo.h"
16 #include "AArch64Subtarget.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/CodeGen/MachineFunctionPass.h"
19 #include "llvm/CodeGen/MachineJumpTableInfo.h"
20 #include "llvm/CodeGen/TargetInstrInfo.h"
21 #include "llvm/CodeGen/TargetSubtargetInfo.h"
22 #include "llvm/MC/MCContext.h"
23 #include "llvm/Support/Alignment.h"
24 #include "llvm/Support/Debug.h"
26 using namespace llvm;
// Identifier for this pass's debug output; it is also passed to
// INITIALIZE_PASS further down as the pass's registration argument.
#define DEBUG_TYPE "aarch64-jump-tables"

// Tallies of the entry width chosen by compressJumpTable() for each
// JumpTableDest32 it inspects. NumJT32 counts tables that were left at
// 4-byte entries: the ADR range check failed or the span was too wide.
STATISTIC(NumJT8, "Number of jump-tables with 1-byte entries");
STATISTIC(NumJT16, "Number of jump-tables with 2-byte entries");
STATISTIC(NumJT32, "Number of jump-tables with 4-byte entries");
34 namespace {
35 class AArch64CompressJumpTables : public MachineFunctionPass {
36 const TargetInstrInfo *TII;
37 MachineFunction *MF;
38 SmallVector<int, 8> BlockInfo;
40 /// Returns the size of instructions in the block \p MBB, or std::nullopt if
41 /// we couldn't get a safe upper bound.
42 std::optional<int> computeBlockSize(MachineBasicBlock &MBB);
44 /// Gather information about the function, returns false if we can't perform
45 /// this optimization for some reason.
46 bool scanFunction();
48 bool compressJumpTable(MachineInstr &MI, int Offset);
50 public:
51 static char ID;
52 AArch64CompressJumpTables() : MachineFunctionPass(ID) {
53 initializeAArch64CompressJumpTablesPass(*PassRegistry::getPassRegistry());
56 bool runOnMachineFunction(MachineFunction &MF) override;
58 MachineFunctionProperties getRequiredProperties() const override {
59 return MachineFunctionProperties().set(
60 MachineFunctionProperties::Property::NoVRegs);
62 StringRef getPassName() const override {
63 return "AArch64 Compress Jump Tables";
66 char AArch64CompressJumpTables::ID = 0;
67 } // namespace
69 INITIALIZE_PASS(AArch64CompressJumpTables, DEBUG_TYPE,
70 "AArch64 compress jump tables pass", false, false)
72 std::optional<int>
73 AArch64CompressJumpTables::computeBlockSize(MachineBasicBlock &MBB) {
74 int Size = 0;
75 for (const MachineInstr &MI : MBB) {
76 // Inline asm may contain some directives like .bytes which we don't
77 // currently have the ability to parse accurately. To be safe, just avoid
78 // computing a size and bail out.
79 if (MI.getOpcode() == AArch64::INLINEASM ||
80 MI.getOpcode() == AArch64::INLINEASM_BR)
81 return std::nullopt;
82 Size += TII->getInstSizeInBytes(MI);
84 return Size;
87 bool AArch64CompressJumpTables::scanFunction() {
88 BlockInfo.clear();
89 BlockInfo.resize(MF->getNumBlockIDs());
91 // NOTE: BlockSize, Offset, OffsetAfterAlignment are all upper bounds.
93 unsigned Offset = 0;
94 for (MachineBasicBlock &MBB : *MF) {
95 const Align Alignment = MBB.getAlignment();
96 unsigned OffsetAfterAlignment = Offset;
97 // We don't know the exact size of MBB so assume worse case padding.
98 if (Alignment > Align(4))
99 OffsetAfterAlignment += Alignment.value() - 4;
100 BlockInfo[MBB.getNumber()] = OffsetAfterAlignment;
101 auto BlockSize = computeBlockSize(MBB);
102 if (!BlockSize)
103 return false;
104 Offset = OffsetAfterAlignment + *BlockSize;
106 return true;
109 bool AArch64CompressJumpTables::compressJumpTable(MachineInstr &MI,
110 int Offset) {
111 if (MI.getOpcode() != AArch64::JumpTableDest32)
112 return false;
114 int JTIdx = MI.getOperand(4).getIndex();
115 auto &JTInfo = *MF->getJumpTableInfo();
116 const MachineJumpTableEntry &JT = JTInfo.getJumpTables()[JTIdx];
118 // The jump-table might have been optimized away.
119 if (JT.MBBs.empty())
120 return false;
122 int MaxOffset = std::numeric_limits<int>::min(),
123 MinOffset = std::numeric_limits<int>::max();
124 MachineBasicBlock *MinBlock = nullptr;
125 for (auto *Block : JT.MBBs) {
126 int BlockOffset = BlockInfo[Block->getNumber()];
127 assert(BlockOffset % 4 == 0 && "misaligned basic block");
129 MaxOffset = std::max(MaxOffset, BlockOffset);
130 if (BlockOffset <= MinOffset) {
131 MinOffset = BlockOffset;
132 MinBlock = Block;
135 assert(MinBlock && "Failed to find minimum offset block");
137 // The ADR instruction needed to calculate the address of the first reachable
138 // basic block can address +/-1MB.
139 if (!isInt<21>(MinOffset - Offset)) {
140 ++NumJT32;
141 return false;
144 int Span = MaxOffset - MinOffset;
145 auto *AFI = MF->getInfo<AArch64FunctionInfo>();
146 if (isUInt<8>(Span / 4)) {
147 AFI->setJumpTableEntryInfo(JTIdx, 1, MinBlock->getSymbol());
148 MI.setDesc(TII->get(AArch64::JumpTableDest8));
149 ++NumJT8;
150 return true;
152 if (isUInt<16>(Span / 4)) {
153 AFI->setJumpTableEntryInfo(JTIdx, 2, MinBlock->getSymbol());
154 MI.setDesc(TII->get(AArch64::JumpTableDest16));
155 ++NumJT16;
156 return true;
159 ++NumJT32;
160 return false;
163 bool AArch64CompressJumpTables::runOnMachineFunction(MachineFunction &MFIn) {
164 bool Changed = false;
165 MF = &MFIn;
167 const auto &ST = MF->getSubtarget<AArch64Subtarget>();
168 TII = ST.getInstrInfo();
170 if (ST.force32BitJumpTables() && !MF->getFunction().hasMinSize())
171 return false;
173 if (!scanFunction())
174 return false;
176 for (MachineBasicBlock &MBB : *MF) {
177 int Offset = BlockInfo[MBB.getNumber()];
178 for (MachineInstr &MI : MBB) {
179 Changed |= compressJumpTable(MI, Offset);
180 Offset += TII->getInstSizeInBytes(MI);
184 return Changed;
187 FunctionPass *llvm::createAArch64CompressJumpTablesPass() {
188 return new AArch64CompressJumpTables();