//===- GCNRegPressure.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "GCNRegPressure.h"
#include "AMDGPUSubtarget.h"
#include "SIRegisterInfo.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/LiveInterval.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegisterPressure.h"
#include "llvm/CodeGen/SlotIndexes.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/MC/LaneBitmask.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "machine-scheduler"
34 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
36 void llvm::printLivesAt(SlotIndex SI
,
37 const LiveIntervals
&LIS
,
38 const MachineRegisterInfo
&MRI
) {
39 dbgs() << "Live regs at " << SI
<< ": "
40 << *LIS
.getInstructionFromIndex(SI
);
42 for (unsigned I
= 0, E
= MRI
.getNumVirtRegs(); I
!= E
; ++I
) {
43 const unsigned Reg
= TargetRegisterInfo::index2VirtReg(I
);
44 if (!LIS
.hasInterval(Reg
))
46 const auto &LI
= LIS
.getInterval(Reg
);
47 if (LI
.hasSubRanges()) {
48 bool firstTime
= true;
49 for (const auto &S
: LI
.subranges()) {
50 if (!S
.liveAt(SI
)) continue;
52 dbgs() << " " << printReg(Reg
, MRI
.getTargetRegisterInfo())
56 dbgs() << " " << S
<< '\n';
59 } else if (LI
.liveAt(SI
)) {
60 dbgs() << " " << LI
<< '\n';
64 if (!Num
) dbgs() << " <none>\n";
67 static bool isEqual(const GCNRPTracker::LiveRegSet
&S1
,
68 const GCNRPTracker::LiveRegSet
&S2
) {
69 if (S1
.size() != S2
.size())
72 for (const auto &P
: S1
) {
73 auto I
= S2
.find(P
.first
);
74 if (I
== S2
.end() || I
->second
!= P
.second
)
///////////////////////////////////////////////////////////////////////////////
// GCNRegPressure
84 unsigned GCNRegPressure::getRegKind(unsigned Reg
,
85 const MachineRegisterInfo
&MRI
) {
86 assert(TargetRegisterInfo::isVirtualRegister(Reg
));
87 const auto RC
= MRI
.getRegClass(Reg
);
88 auto STI
= static_cast<const SIRegisterInfo
*>(MRI
.getTargetRegisterInfo());
89 return STI
->isSGPRClass(RC
) ?
90 (STI
->getRegSizeInBits(*RC
) == 32 ? SGPR32
: SGPR_TUPLE
) :
91 (STI
->getRegSizeInBits(*RC
) == 32 ? VGPR32
: VGPR_TUPLE
);
94 void GCNRegPressure::inc(unsigned Reg
,
97 const MachineRegisterInfo
&MRI
) {
98 if (NewMask
== PrevMask
)
102 if (NewMask
< PrevMask
) {
103 std::swap(NewMask
, PrevMask
);
107 const auto MaxMask
= MRI
.getMaxLaneMaskForVReg(Reg
);
109 switch (auto Kind
= getRegKind(Reg
, MRI
)) {
112 assert(PrevMask
.none() && NewMask
== MaxMask
);
118 assert(NewMask
< MaxMask
|| NewMask
== MaxMask
);
119 assert(PrevMask
< NewMask
);
121 Value
[Kind
== SGPR_TUPLE
? SGPR32
: VGPR32
] +=
122 Sign
* (~PrevMask
& NewMask
).getNumLanes();
124 if (PrevMask
.none()) {
125 assert(NewMask
.any());
126 Value
[Kind
] += Sign
* MRI
.getPressureSets(Reg
).getWeight();
130 default: llvm_unreachable("Unknown register kind");
134 bool GCNRegPressure::less(const GCNSubtarget
&ST
,
135 const GCNRegPressure
& O
,
136 unsigned MaxOccupancy
) const {
137 const auto SGPROcc
= std::min(MaxOccupancy
,
138 ST
.getOccupancyWithNumSGPRs(getSGPRNum()));
139 const auto VGPROcc
= std::min(MaxOccupancy
,
140 ST
.getOccupancyWithNumVGPRs(getVGPRNum()));
141 const auto OtherSGPROcc
= std::min(MaxOccupancy
,
142 ST
.getOccupancyWithNumSGPRs(O
.getSGPRNum()));
143 const auto OtherVGPROcc
= std::min(MaxOccupancy
,
144 ST
.getOccupancyWithNumVGPRs(O
.getVGPRNum()));
146 const auto Occ
= std::min(SGPROcc
, VGPROcc
);
147 const auto OtherOcc
= std::min(OtherSGPROcc
, OtherVGPROcc
);
149 return Occ
> OtherOcc
;
151 bool SGPRImportant
= SGPROcc
< VGPROcc
;
152 const bool OtherSGPRImportant
= OtherSGPROcc
< OtherVGPROcc
;
154 // if both pressures disagree on what is more important compare vgprs
155 if (SGPRImportant
!= OtherSGPRImportant
) {
156 SGPRImportant
= false;
159 // compare large regs pressure
160 bool SGPRFirst
= SGPRImportant
;
161 for (int I
= 2; I
> 0; --I
, SGPRFirst
= !SGPRFirst
) {
163 auto SW
= getSGPRTuplesWeight();
164 auto OtherSW
= O
.getSGPRTuplesWeight();
168 auto VW
= getVGPRTuplesWeight();
169 auto OtherVW
= O
.getVGPRTuplesWeight();
174 return SGPRImportant
? (getSGPRNum() < O
.getSGPRNum()):
175 (getVGPRNum() < O
.getVGPRNum());
178 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
180 void GCNRegPressure::print(raw_ostream
&OS
, const GCNSubtarget
*ST
) const {
181 OS
<< "VGPRs: " << getVGPRNum();
182 if (ST
) OS
<< "(O" << ST
->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
183 OS
<< ", SGPRs: " << getSGPRNum();
184 if (ST
) OS
<< "(O" << ST
->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
185 OS
<< ", LVGPR WT: " << getVGPRTuplesWeight()
186 << ", LSGPR WT: " << getSGPRTuplesWeight();
187 if (ST
) OS
<< " -> Occ: " << getOccupancy(*ST
);
192 static LaneBitmask
getDefRegMask(const MachineOperand
&MO
,
193 const MachineRegisterInfo
&MRI
) {
194 assert(MO
.isDef() && MO
.isReg() &&
195 TargetRegisterInfo::isVirtualRegister(MO
.getReg()));
197 // We don't rely on read-undef flag because in case of tentative schedule
198 // tracking it isn't set correctly yet. This works correctly however since
199 // use mask has been tracked before using LIS.
200 return MO
.getSubReg() == 0 ?
201 MRI
.getMaxLaneMaskForVReg(MO
.getReg()) :
202 MRI
.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO
.getSubReg());
205 static LaneBitmask
getUsedRegMask(const MachineOperand
&MO
,
206 const MachineRegisterInfo
&MRI
,
207 const LiveIntervals
&LIS
) {
208 assert(MO
.isUse() && MO
.isReg() &&
209 TargetRegisterInfo::isVirtualRegister(MO
.getReg()));
211 if (auto SubReg
= MO
.getSubReg())
212 return MRI
.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg
);
214 auto MaxMask
= MRI
.getMaxLaneMaskForVReg(MO
.getReg());
215 if (MaxMask
== LaneBitmask::getLane(0)) // cannot have subregs
218 // For a tentative schedule LIS isn't updated yet but livemask should remain
219 // the same on any schedule. Subreg defs can be reordered but they all must
220 // dominate uses anyway.
221 auto SI
= LIS
.getInstructionIndex(*MO
.getParent()).getBaseIndex();
222 return getLiveLaneMask(MO
.getReg(), SI
, LIS
, MRI
);
225 static SmallVector
<RegisterMaskPair
, 8>
226 collectVirtualRegUses(const MachineInstr
&MI
, const LiveIntervals
&LIS
,
227 const MachineRegisterInfo
&MRI
) {
228 SmallVector
<RegisterMaskPair
, 8> Res
;
229 for (const auto &MO
: MI
.operands()) {
230 if (!MO
.isReg() || !TargetRegisterInfo::isVirtualRegister(MO
.getReg()))
232 if (!MO
.isUse() || !MO
.readsReg())
235 auto const UsedMask
= getUsedRegMask(MO
, MRI
, LIS
);
237 auto Reg
= MO
.getReg();
238 auto I
= std::find_if(Res
.begin(), Res
.end(), [Reg
](const RegisterMaskPair
&RM
) {
239 return RM
.RegUnit
== Reg
;
242 I
->LaneMask
|= UsedMask
;
244 Res
.push_back(RegisterMaskPair(Reg
, UsedMask
));
///////////////////////////////////////////////////////////////////////////////
// GCNRPTracker
252 LaneBitmask
llvm::getLiveLaneMask(unsigned Reg
,
254 const LiveIntervals
&LIS
,
255 const MachineRegisterInfo
&MRI
) {
256 LaneBitmask LiveMask
;
257 const auto &LI
= LIS
.getInterval(Reg
);
258 if (LI
.hasSubRanges()) {
259 for (const auto &S
: LI
.subranges())
261 LiveMask
|= S
.LaneMask
;
262 assert(LiveMask
< MRI
.getMaxLaneMaskForVReg(Reg
) ||
263 LiveMask
== MRI
.getMaxLaneMaskForVReg(Reg
));
265 } else if (LI
.liveAt(SI
)) {
266 LiveMask
= MRI
.getMaxLaneMaskForVReg(Reg
);
271 GCNRPTracker::LiveRegSet
llvm::getLiveRegs(SlotIndex SI
,
272 const LiveIntervals
&LIS
,
273 const MachineRegisterInfo
&MRI
) {
274 GCNRPTracker::LiveRegSet LiveRegs
;
275 for (unsigned I
= 0, E
= MRI
.getNumVirtRegs(); I
!= E
; ++I
) {
276 auto Reg
= TargetRegisterInfo::index2VirtReg(I
);
277 if (!LIS
.hasInterval(Reg
))
279 auto LiveMask
= getLiveLaneMask(Reg
, SI
, LIS
, MRI
);
281 LiveRegs
[Reg
] = LiveMask
;
286 void GCNRPTracker::reset(const MachineInstr
&MI
,
287 const LiveRegSet
*LiveRegsCopy
,
289 const MachineFunction
&MF
= *MI
.getMF();
290 MRI
= &MF
.getRegInfo();
292 if (&LiveRegs
!= LiveRegsCopy
)
293 LiveRegs
= *LiveRegsCopy
;
295 LiveRegs
= After
? getLiveRegsAfter(MI
, LIS
)
296 : getLiveRegsBefore(MI
, LIS
);
299 MaxPressure
= CurPressure
= getRegPressure(*MRI
, LiveRegs
);
302 void GCNUpwardRPTracker::reset(const MachineInstr
&MI
,
303 const LiveRegSet
*LiveRegsCopy
) {
304 GCNRPTracker::reset(MI
, LiveRegsCopy
, true);
307 void GCNUpwardRPTracker::recede(const MachineInstr
&MI
) {
308 assert(MRI
&& "call reset first");
312 if (MI
.isDebugInstr())
315 auto const RegUses
= collectVirtualRegUses(MI
, LIS
, *MRI
);
317 // calc pressure at the MI (defs + uses)
318 auto AtMIPressure
= CurPressure
;
319 for (const auto &U
: RegUses
) {
320 auto LiveMask
= LiveRegs
[U
.RegUnit
];
321 AtMIPressure
.inc(U
.RegUnit
, LiveMask
, LiveMask
| U
.LaneMask
, *MRI
);
323 // update max pressure
324 MaxPressure
= max(AtMIPressure
, MaxPressure
);
326 for (const auto &MO
: MI
.defs()) {
327 if (!MO
.isReg() || !TargetRegisterInfo::isVirtualRegister(MO
.getReg()) ||
331 auto Reg
= MO
.getReg();
332 auto I
= LiveRegs
.find(Reg
);
333 if (I
== LiveRegs
.end())
335 auto &LiveMask
= I
->second
;
336 auto PrevMask
= LiveMask
;
337 LiveMask
&= ~getDefRegMask(MO
, *MRI
);
338 CurPressure
.inc(Reg
, PrevMask
, LiveMask
, *MRI
);
342 for (const auto &U
: RegUses
) {
343 auto &LiveMask
= LiveRegs
[U
.RegUnit
];
344 auto PrevMask
= LiveMask
;
345 LiveMask
|= U
.LaneMask
;
346 CurPressure
.inc(U
.RegUnit
, PrevMask
, LiveMask
, *MRI
);
348 assert(CurPressure
== getRegPressure(*MRI
, LiveRegs
));
351 bool GCNDownwardRPTracker::reset(const MachineInstr
&MI
,
352 const LiveRegSet
*LiveRegsCopy
) {
353 MRI
= &MI
.getParent()->getParent()->getRegInfo();
354 LastTrackedMI
= nullptr;
355 MBBEnd
= MI
.getParent()->end();
357 NextMI
= skipDebugInstructionsForward(NextMI
, MBBEnd
);
358 if (NextMI
== MBBEnd
)
360 GCNRPTracker::reset(*NextMI
, LiveRegsCopy
, false);
364 bool GCNDownwardRPTracker::advanceBeforeNext() {
365 assert(MRI
&& "call reset first");
367 NextMI
= skipDebugInstructionsForward(NextMI
, MBBEnd
);
368 if (NextMI
== MBBEnd
)
371 SlotIndex SI
= LIS
.getInstructionIndex(*NextMI
).getBaseIndex();
372 assert(SI
.isValid());
374 // Remove dead registers or mask bits.
375 for (auto &It
: LiveRegs
) {
376 const LiveInterval
&LI
= LIS
.getInterval(It
.first
);
377 if (LI
.hasSubRanges()) {
378 for (const auto &S
: LI
.subranges()) {
380 auto PrevMask
= It
.second
;
381 It
.second
&= ~S
.LaneMask
;
382 CurPressure
.inc(It
.first
, PrevMask
, It
.second
, *MRI
);
385 } else if (!LI
.liveAt(SI
)) {
386 auto PrevMask
= It
.second
;
387 It
.second
= LaneBitmask::getNone();
388 CurPressure
.inc(It
.first
, PrevMask
, It
.second
, *MRI
);
390 if (It
.second
.none())
391 LiveRegs
.erase(It
.first
);
394 MaxPressure
= max(MaxPressure
, CurPressure
);
399 void GCNDownwardRPTracker::advanceToNext() {
400 LastTrackedMI
= &*NextMI
++;
402 // Add new registers or mask bits.
403 for (const auto &MO
: LastTrackedMI
->defs()) {
406 unsigned Reg
= MO
.getReg();
407 if (!TargetRegisterInfo::isVirtualRegister(Reg
))
409 auto &LiveMask
= LiveRegs
[Reg
];
410 auto PrevMask
= LiveMask
;
411 LiveMask
|= getDefRegMask(MO
, *MRI
);
412 CurPressure
.inc(Reg
, PrevMask
, LiveMask
, *MRI
);
415 MaxPressure
= max(MaxPressure
, CurPressure
);
418 bool GCNDownwardRPTracker::advance() {
419 // If we have just called reset live set is actual.
420 if ((NextMI
== MBBEnd
) || (LastTrackedMI
&& !advanceBeforeNext()))
426 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End
) {
427 while (NextMI
!= End
)
428 if (!advance()) return false;
432 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin
,
433 MachineBasicBlock::const_iterator End
,
434 const LiveRegSet
*LiveRegsCopy
) {
435 reset(*Begin
, LiveRegsCopy
);
439 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
441 static void reportMismatch(const GCNRPTracker::LiveRegSet
&LISLR
,
442 const GCNRPTracker::LiveRegSet
&TrackedLR
,
443 const TargetRegisterInfo
*TRI
) {
444 for (auto const &P
: TrackedLR
) {
445 auto I
= LISLR
.find(P
.first
);
446 if (I
== LISLR
.end()) {
447 dbgs() << " " << printReg(P
.first
, TRI
)
448 << ":L" << PrintLaneMask(P
.second
)
449 << " isn't found in LIS reported set\n";
451 else if (I
->second
!= P
.second
) {
452 dbgs() << " " << printReg(P
.first
, TRI
)
453 << " masks doesn't match: LIS reported "
454 << PrintLaneMask(I
->second
)
456 << PrintLaneMask(P
.second
)
460 for (auto const &P
: LISLR
) {
461 auto I
= TrackedLR
.find(P
.first
);
462 if (I
== TrackedLR
.end()) {
463 dbgs() << " " << printReg(P
.first
, TRI
)
464 << ":L" << PrintLaneMask(P
.second
)
465 << " isn't found in tracked set\n";
470 bool GCNUpwardRPTracker::isValid() const {
471 const auto &SI
= LIS
.getInstructionIndex(*LastTrackedMI
).getBaseIndex();
472 const auto LISLR
= llvm::getLiveRegs(SI
, LIS
, *MRI
);
473 const auto &TrackedLR
= LiveRegs
;
475 if (!isEqual(LISLR
, TrackedLR
)) {
476 dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
477 " LIS reported livesets mismatch:\n";
478 printLivesAt(SI
, LIS
, *MRI
);
479 reportMismatch(LISLR
, TrackedLR
, MRI
->getTargetRegisterInfo());
483 auto LISPressure
= getRegPressure(*MRI
, LISLR
);
484 if (LISPressure
!= CurPressure
) {
485 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
486 CurPressure
.print(dbgs());
487 dbgs() << "LIS rpt: ";
488 LISPressure
.print(dbgs());
494 void GCNRPTracker::printLiveRegs(raw_ostream
&OS
, const LiveRegSet
& LiveRegs
,
495 const MachineRegisterInfo
&MRI
) {
496 const TargetRegisterInfo
*TRI
= MRI
.getTargetRegisterInfo();
497 for (unsigned I
= 0, E
= MRI
.getNumVirtRegs(); I
!= E
; ++I
) {
498 unsigned Reg
= TargetRegisterInfo::index2VirtReg(I
);
499 auto It
= LiveRegs
.find(Reg
);
500 if (It
!= LiveRegs
.end() && It
->second
.any())
501 OS
<< ' ' << printVRegOrUnit(Reg
, TRI
) << ':'
502 << PrintLaneMask(It
->second
);