1 //===- GCNRegPressure.cpp -------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "GCNRegPressure.h"
10 #include "AMDGPUSubtarget.h"
11 #include "SIRegisterInfo.h"
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/CodeGen/LiveInterval.h"
14 #include "llvm/CodeGen/LiveIntervals.h"
15 #include "llvm/CodeGen/MachineInstr.h"
16 #include "llvm/CodeGen/MachineOperand.h"
17 #include "llvm/CodeGen/MachineRegisterInfo.h"
18 #include "llvm/CodeGen/RegisterPressure.h"
19 #include "llvm/CodeGen/SlotIndexes.h"
20 #include "llvm/CodeGen/TargetRegisterInfo.h"
21 #include "llvm/Config/llvm-config.h"
22 #include "llvm/MC/LaneBitmask.h"
23 #include "llvm/Support/Compiler.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/raw_ostream.h"
32 #define DEBUG_TYPE "machine-scheduler"
34 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
36 void llvm::printLivesAt(SlotIndex SI
,
37 const LiveIntervals
&LIS
,
38 const MachineRegisterInfo
&MRI
) {
39 dbgs() << "Live regs at " << SI
<< ": "
40 << *LIS
.getInstructionFromIndex(SI
);
42 for (unsigned I
= 0, E
= MRI
.getNumVirtRegs(); I
!= E
; ++I
) {
43 const unsigned Reg
= Register::index2VirtReg(I
);
44 if (!LIS
.hasInterval(Reg
))
46 const auto &LI
= LIS
.getInterval(Reg
);
47 if (LI
.hasSubRanges()) {
48 bool firstTime
= true;
49 for (const auto &S
: LI
.subranges()) {
50 if (!S
.liveAt(SI
)) continue;
52 dbgs() << " " << printReg(Reg
, MRI
.getTargetRegisterInfo())
56 dbgs() << " " << S
<< '\n';
59 } else if (LI
.liveAt(SI
)) {
60 dbgs() << " " << LI
<< '\n';
64 if (!Num
) dbgs() << " <none>\n";
68 bool llvm::isEqual(const GCNRPTracker::LiveRegSet
&S1
,
69 const GCNRPTracker::LiveRegSet
&S2
) {
70 if (S1
.size() != S2
.size())
73 for (const auto &P
: S1
) {
74 auto I
= S2
.find(P
.first
);
75 if (I
== S2
.end() || I
->second
!= P
.second
)
82 ///////////////////////////////////////////////////////////////////////////////
85 unsigned GCNRegPressure::getRegKind(unsigned Reg
,
86 const MachineRegisterInfo
&MRI
) {
87 assert(Register::isVirtualRegister(Reg
));
88 const auto RC
= MRI
.getRegClass(Reg
);
89 auto STI
= static_cast<const SIRegisterInfo
*>(MRI
.getTargetRegisterInfo());
90 return STI
->isSGPRClass(RC
) ?
91 (STI
->getRegSizeInBits(*RC
) == 32 ? SGPR32
: SGPR_TUPLE
) :
93 (STI
->getRegSizeInBits(*RC
) == 32 ? AGPR32
: AGPR_TUPLE
) :
94 (STI
->getRegSizeInBits(*RC
) == 32 ? VGPR32
: VGPR_TUPLE
);
97 void GCNRegPressure::inc(unsigned Reg
,
100 const MachineRegisterInfo
&MRI
) {
101 if (NewMask
== PrevMask
)
105 if (NewMask
< PrevMask
) {
106 std::swap(NewMask
, PrevMask
);
110 const auto MaxMask
= MRI
.getMaxLaneMaskForVReg(Reg
);
112 switch (auto Kind
= getRegKind(Reg
, MRI
)) {
116 assert(PrevMask
.none() && NewMask
== MaxMask
);
123 assert(NewMask
< MaxMask
|| NewMask
== MaxMask
);
124 assert(PrevMask
< NewMask
);
126 Value
[Kind
== SGPR_TUPLE
? SGPR32
: Kind
== AGPR_TUPLE
? AGPR32
: VGPR32
] +=
127 Sign
* (~PrevMask
& NewMask
).getNumLanes();
129 if (PrevMask
.none()) {
130 assert(NewMask
.any());
131 Value
[Kind
] += Sign
* MRI
.getPressureSets(Reg
).getWeight();
135 default: llvm_unreachable("Unknown register kind");
139 bool GCNRegPressure::less(const GCNSubtarget
&ST
,
140 const GCNRegPressure
& O
,
141 unsigned MaxOccupancy
) const {
142 const auto SGPROcc
= std::min(MaxOccupancy
,
143 ST
.getOccupancyWithNumSGPRs(getSGPRNum()));
144 const auto VGPROcc
= std::min(MaxOccupancy
,
145 ST
.getOccupancyWithNumVGPRs(getVGPRNum()));
146 const auto OtherSGPROcc
= std::min(MaxOccupancy
,
147 ST
.getOccupancyWithNumSGPRs(O
.getSGPRNum()));
148 const auto OtherVGPROcc
= std::min(MaxOccupancy
,
149 ST
.getOccupancyWithNumVGPRs(O
.getVGPRNum()));
151 const auto Occ
= std::min(SGPROcc
, VGPROcc
);
152 const auto OtherOcc
= std::min(OtherSGPROcc
, OtherVGPROcc
);
154 return Occ
> OtherOcc
;
156 bool SGPRImportant
= SGPROcc
< VGPROcc
;
157 const bool OtherSGPRImportant
= OtherSGPROcc
< OtherVGPROcc
;
159 // if both pressures disagree on what is more important compare vgprs
160 if (SGPRImportant
!= OtherSGPRImportant
) {
161 SGPRImportant
= false;
164 // compare large regs pressure
165 bool SGPRFirst
= SGPRImportant
;
166 for (int I
= 2; I
> 0; --I
, SGPRFirst
= !SGPRFirst
) {
168 auto SW
= getSGPRTuplesWeight();
169 auto OtherSW
= O
.getSGPRTuplesWeight();
173 auto VW
= getVGPRTuplesWeight();
174 auto OtherVW
= O
.getVGPRTuplesWeight();
179 return SGPRImportant
? (getSGPRNum() < O
.getSGPRNum()):
180 (getVGPRNum() < O
.getVGPRNum());
183 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
185 void GCNRegPressure::print(raw_ostream
&OS
, const GCNSubtarget
*ST
) const {
186 OS
<< "VGPRs: " << Value
[VGPR32
] << ' ';
187 OS
<< "AGPRs: " << Value
[AGPR32
];
188 if (ST
) OS
<< "(O" << ST
->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
189 OS
<< ", SGPRs: " << getSGPRNum();
190 if (ST
) OS
<< "(O" << ST
->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
191 OS
<< ", LVGPR WT: " << getVGPRTuplesWeight()
192 << ", LSGPR WT: " << getSGPRTuplesWeight();
193 if (ST
) OS
<< " -> Occ: " << getOccupancy(*ST
);
198 static LaneBitmask
getDefRegMask(const MachineOperand
&MO
,
199 const MachineRegisterInfo
&MRI
) {
200 assert(MO
.isDef() && MO
.isReg() && Register::isVirtualRegister(MO
.getReg()));
202 // We don't rely on read-undef flag because in case of tentative schedule
203 // tracking it isn't set correctly yet. This works correctly however since
204 // use mask has been tracked before using LIS.
205 return MO
.getSubReg() == 0 ?
206 MRI
.getMaxLaneMaskForVReg(MO
.getReg()) :
207 MRI
.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO
.getSubReg());
210 static LaneBitmask
getUsedRegMask(const MachineOperand
&MO
,
211 const MachineRegisterInfo
&MRI
,
212 const LiveIntervals
&LIS
) {
213 assert(MO
.isUse() && MO
.isReg() && Register::isVirtualRegister(MO
.getReg()));
215 if (auto SubReg
= MO
.getSubReg())
216 return MRI
.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg
);
218 auto MaxMask
= MRI
.getMaxLaneMaskForVReg(MO
.getReg());
219 if (MaxMask
== LaneBitmask::getLane(0)) // cannot have subregs
222 // For a tentative schedule LIS isn't updated yet but livemask should remain
223 // the same on any schedule. Subreg defs can be reordered but they all must
224 // dominate uses anyway.
225 auto SI
= LIS
.getInstructionIndex(*MO
.getParent()).getBaseIndex();
226 return getLiveLaneMask(MO
.getReg(), SI
, LIS
, MRI
);
229 static SmallVector
<RegisterMaskPair
, 8>
230 collectVirtualRegUses(const MachineInstr
&MI
, const LiveIntervals
&LIS
,
231 const MachineRegisterInfo
&MRI
) {
232 SmallVector
<RegisterMaskPair
, 8> Res
;
233 for (const auto &MO
: MI
.operands()) {
234 if (!MO
.isReg() || !Register::isVirtualRegister(MO
.getReg()))
236 if (!MO
.isUse() || !MO
.readsReg())
239 auto const UsedMask
= getUsedRegMask(MO
, MRI
, LIS
);
241 auto Reg
= MO
.getReg();
242 auto I
= std::find_if(Res
.begin(), Res
.end(), [Reg
](const RegisterMaskPair
&RM
) {
243 return RM
.RegUnit
== Reg
;
246 I
->LaneMask
|= UsedMask
;
248 Res
.push_back(RegisterMaskPair(Reg
, UsedMask
));
253 ///////////////////////////////////////////////////////////////////////////////
256 LaneBitmask
llvm::getLiveLaneMask(unsigned Reg
,
258 const LiveIntervals
&LIS
,
259 const MachineRegisterInfo
&MRI
) {
260 LaneBitmask LiveMask
;
261 const auto &LI
= LIS
.getInterval(Reg
);
262 if (LI
.hasSubRanges()) {
263 for (const auto &S
: LI
.subranges())
265 LiveMask
|= S
.LaneMask
;
266 assert(LiveMask
< MRI
.getMaxLaneMaskForVReg(Reg
) ||
267 LiveMask
== MRI
.getMaxLaneMaskForVReg(Reg
));
269 } else if (LI
.liveAt(SI
)) {
270 LiveMask
= MRI
.getMaxLaneMaskForVReg(Reg
);
275 GCNRPTracker::LiveRegSet
llvm::getLiveRegs(SlotIndex SI
,
276 const LiveIntervals
&LIS
,
277 const MachineRegisterInfo
&MRI
) {
278 GCNRPTracker::LiveRegSet LiveRegs
;
279 for (unsigned I
= 0, E
= MRI
.getNumVirtRegs(); I
!= E
; ++I
) {
280 auto Reg
= Register::index2VirtReg(I
);
281 if (!LIS
.hasInterval(Reg
))
283 auto LiveMask
= getLiveLaneMask(Reg
, SI
, LIS
, MRI
);
285 LiveRegs
[Reg
] = LiveMask
;
290 void GCNRPTracker::reset(const MachineInstr
&MI
,
291 const LiveRegSet
*LiveRegsCopy
,
293 const MachineFunction
&MF
= *MI
.getMF();
294 MRI
= &MF
.getRegInfo();
296 if (&LiveRegs
!= LiveRegsCopy
)
297 LiveRegs
= *LiveRegsCopy
;
299 LiveRegs
= After
? getLiveRegsAfter(MI
, LIS
)
300 : getLiveRegsBefore(MI
, LIS
);
303 MaxPressure
= CurPressure
= getRegPressure(*MRI
, LiveRegs
);
306 void GCNUpwardRPTracker::reset(const MachineInstr
&MI
,
307 const LiveRegSet
*LiveRegsCopy
) {
308 GCNRPTracker::reset(MI
, LiveRegsCopy
, true);
311 void GCNUpwardRPTracker::recede(const MachineInstr
&MI
) {
312 assert(MRI
&& "call reset first");
316 if (MI
.isDebugInstr())
319 auto const RegUses
= collectVirtualRegUses(MI
, LIS
, *MRI
);
321 // calc pressure at the MI (defs + uses)
322 auto AtMIPressure
= CurPressure
;
323 for (const auto &U
: RegUses
) {
324 auto LiveMask
= LiveRegs
[U
.RegUnit
];
325 AtMIPressure
.inc(U
.RegUnit
, LiveMask
, LiveMask
| U
.LaneMask
, *MRI
);
327 // update max pressure
328 MaxPressure
= max(AtMIPressure
, MaxPressure
);
330 for (const auto &MO
: MI
.defs()) {
331 if (!MO
.isReg() || !Register::isVirtualRegister(MO
.getReg()) || MO
.isDead())
334 auto Reg
= MO
.getReg();
335 auto I
= LiveRegs
.find(Reg
);
336 if (I
== LiveRegs
.end())
338 auto &LiveMask
= I
->second
;
339 auto PrevMask
= LiveMask
;
340 LiveMask
&= ~getDefRegMask(MO
, *MRI
);
341 CurPressure
.inc(Reg
, PrevMask
, LiveMask
, *MRI
);
345 for (const auto &U
: RegUses
) {
346 auto &LiveMask
= LiveRegs
[U
.RegUnit
];
347 auto PrevMask
= LiveMask
;
348 LiveMask
|= U
.LaneMask
;
349 CurPressure
.inc(U
.RegUnit
, PrevMask
, LiveMask
, *MRI
);
351 assert(CurPressure
== getRegPressure(*MRI
, LiveRegs
));
354 bool GCNDownwardRPTracker::reset(const MachineInstr
&MI
,
355 const LiveRegSet
*LiveRegsCopy
) {
356 MRI
= &MI
.getParent()->getParent()->getRegInfo();
357 LastTrackedMI
= nullptr;
358 MBBEnd
= MI
.getParent()->end();
360 NextMI
= skipDebugInstructionsForward(NextMI
, MBBEnd
);
361 if (NextMI
== MBBEnd
)
363 GCNRPTracker::reset(*NextMI
, LiveRegsCopy
, false);
367 bool GCNDownwardRPTracker::advanceBeforeNext() {
368 assert(MRI
&& "call reset first");
370 NextMI
= skipDebugInstructionsForward(NextMI
, MBBEnd
);
371 if (NextMI
== MBBEnd
)
374 SlotIndex SI
= LIS
.getInstructionIndex(*NextMI
).getBaseIndex();
375 assert(SI
.isValid());
377 // Remove dead registers or mask bits.
378 for (auto &It
: LiveRegs
) {
379 const LiveInterval
&LI
= LIS
.getInterval(It
.first
);
380 if (LI
.hasSubRanges()) {
381 for (const auto &S
: LI
.subranges()) {
383 auto PrevMask
= It
.second
;
384 It
.second
&= ~S
.LaneMask
;
385 CurPressure
.inc(It
.first
, PrevMask
, It
.second
, *MRI
);
388 } else if (!LI
.liveAt(SI
)) {
389 auto PrevMask
= It
.second
;
390 It
.second
= LaneBitmask::getNone();
391 CurPressure
.inc(It
.first
, PrevMask
, It
.second
, *MRI
);
393 if (It
.second
.none())
394 LiveRegs
.erase(It
.first
);
397 MaxPressure
= max(MaxPressure
, CurPressure
);
402 void GCNDownwardRPTracker::advanceToNext() {
403 LastTrackedMI
= &*NextMI
++;
405 // Add new registers or mask bits.
406 for (const auto &MO
: LastTrackedMI
->defs()) {
409 Register Reg
= MO
.getReg();
410 if (!Register::isVirtualRegister(Reg
))
412 auto &LiveMask
= LiveRegs
[Reg
];
413 auto PrevMask
= LiveMask
;
414 LiveMask
|= getDefRegMask(MO
, *MRI
);
415 CurPressure
.inc(Reg
, PrevMask
, LiveMask
, *MRI
);
418 MaxPressure
= max(MaxPressure
, CurPressure
);
421 bool GCNDownwardRPTracker::advance() {
422 // If we have just called reset live set is actual.
423 if ((NextMI
== MBBEnd
) || (LastTrackedMI
&& !advanceBeforeNext()))
429 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End
) {
430 while (NextMI
!= End
)
431 if (!advance()) return false;
435 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin
,
436 MachineBasicBlock::const_iterator End
,
437 const LiveRegSet
*LiveRegsCopy
) {
438 reset(*Begin
, LiveRegsCopy
);
442 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
444 static void reportMismatch(const GCNRPTracker::LiveRegSet
&LISLR
,
445 const GCNRPTracker::LiveRegSet
&TrackedLR
,
446 const TargetRegisterInfo
*TRI
) {
447 for (auto const &P
: TrackedLR
) {
448 auto I
= LISLR
.find(P
.first
);
449 if (I
== LISLR
.end()) {
450 dbgs() << " " << printReg(P
.first
, TRI
)
451 << ":L" << PrintLaneMask(P
.second
)
452 << " isn't found in LIS reported set\n";
454 else if (I
->second
!= P
.second
) {
455 dbgs() << " " << printReg(P
.first
, TRI
)
456 << " masks doesn't match: LIS reported "
457 << PrintLaneMask(I
->second
)
459 << PrintLaneMask(P
.second
)
463 for (auto const &P
: LISLR
) {
464 auto I
= TrackedLR
.find(P
.first
);
465 if (I
== TrackedLR
.end()) {
466 dbgs() << " " << printReg(P
.first
, TRI
)
467 << ":L" << PrintLaneMask(P
.second
)
468 << " isn't found in tracked set\n";
473 bool GCNUpwardRPTracker::isValid() const {
474 const auto &SI
= LIS
.getInstructionIndex(*LastTrackedMI
).getBaseIndex();
475 const auto LISLR
= llvm::getLiveRegs(SI
, LIS
, *MRI
);
476 const auto &TrackedLR
= LiveRegs
;
478 if (!isEqual(LISLR
, TrackedLR
)) {
479 dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
480 " LIS reported livesets mismatch:\n";
481 printLivesAt(SI
, LIS
, *MRI
);
482 reportMismatch(LISLR
, TrackedLR
, MRI
->getTargetRegisterInfo());
486 auto LISPressure
= getRegPressure(*MRI
, LISLR
);
487 if (LISPressure
!= CurPressure
) {
488 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
489 CurPressure
.print(dbgs());
490 dbgs() << "LIS rpt: ";
491 LISPressure
.print(dbgs());
497 void GCNRPTracker::printLiveRegs(raw_ostream
&OS
, const LiveRegSet
& LiveRegs
,
498 const MachineRegisterInfo
&MRI
) {
499 const TargetRegisterInfo
*TRI
= MRI
.getTargetRegisterInfo();
500 for (unsigned I
= 0, E
= MRI
.getNumVirtRegs(); I
!= E
; ++I
) {
501 unsigned Reg
= Register::index2VirtReg(I
);
502 auto It
= LiveRegs
.find(Reg
);
503 if (It
!= LiveRegs
.end() && It
->second
.any())
504 OS
<< ' ' << printVRegOrUnit(Reg
, TRI
) << ':'
505 << PrintLaneMask(It
->second
);