1 //===- GCNRegPressure.cpp -------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "GCNRegPressure.h"
10 #include "AMDGPUSubtarget.h"
11 #include "SIRegisterInfo.h"
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/CodeGen/LiveInterval.h"
14 #include "llvm/CodeGen/LiveIntervals.h"
15 #include "llvm/CodeGen/MachineInstr.h"
16 #include "llvm/CodeGen/MachineOperand.h"
17 #include "llvm/CodeGen/MachineRegisterInfo.h"
18 #include "llvm/CodeGen/RegisterPressure.h"
19 #include "llvm/CodeGen/SlotIndexes.h"
20 #include "llvm/CodeGen/TargetRegisterInfo.h"
21 #include "llvm/Config/llvm-config.h"
22 #include "llvm/MC/LaneBitmask.h"
23 #include "llvm/Support/Compiler.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/raw_ostream.h"
32 #define DEBUG_TYPE "machine-scheduler"
34 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
36 void llvm::printLivesAt(SlotIndex SI
,
37 const LiveIntervals
&LIS
,
38 const MachineRegisterInfo
&MRI
) {
39 dbgs() << "Live regs at " << SI
<< ": "
40 << *LIS
.getInstructionFromIndex(SI
);
42 for (unsigned I
= 0, E
= MRI
.getNumVirtRegs(); I
!= E
; ++I
) {
43 const unsigned Reg
= TargetRegisterInfo::index2VirtReg(I
);
44 if (!LIS
.hasInterval(Reg
))
46 const auto &LI
= LIS
.getInterval(Reg
);
47 if (LI
.hasSubRanges()) {
48 bool firstTime
= true;
49 for (const auto &S
: LI
.subranges()) {
50 if (!S
.liveAt(SI
)) continue;
52 dbgs() << " " << printReg(Reg
, MRI
.getTargetRegisterInfo())
56 dbgs() << " " << S
<< '\n';
59 } else if (LI
.liveAt(SI
)) {
60 dbgs() << " " << LI
<< '\n';
64 if (!Num
) dbgs() << " <none>\n";
68 bool llvm::isEqual(const GCNRPTracker::LiveRegSet
&S1
,
69 const GCNRPTracker::LiveRegSet
&S2
) {
70 if (S1
.size() != S2
.size())
73 for (const auto &P
: S1
) {
74 auto I
= S2
.find(P
.first
);
75 if (I
== S2
.end() || I
->second
!= P
.second
)
82 ///////////////////////////////////////////////////////////////////////////////
85 unsigned GCNRegPressure::getRegKind(unsigned Reg
,
86 const MachineRegisterInfo
&MRI
) {
87 assert(TargetRegisterInfo::isVirtualRegister(Reg
));
88 const auto RC
= MRI
.getRegClass(Reg
);
89 auto STI
= static_cast<const SIRegisterInfo
*>(MRI
.getTargetRegisterInfo());
90 return STI
->isSGPRClass(RC
) ?
91 (STI
->getRegSizeInBits(*RC
) == 32 ? SGPR32
: SGPR_TUPLE
) :
93 (STI
->getRegSizeInBits(*RC
) == 32 ? AGPR32
: AGPR_TUPLE
) :
94 (STI
->getRegSizeInBits(*RC
) == 32 ? VGPR32
: VGPR_TUPLE
);
97 void GCNRegPressure::inc(unsigned Reg
,
100 const MachineRegisterInfo
&MRI
) {
101 if (NewMask
== PrevMask
)
105 if (NewMask
< PrevMask
) {
106 std::swap(NewMask
, PrevMask
);
110 const auto MaxMask
= MRI
.getMaxLaneMaskForVReg(Reg
);
112 switch (auto Kind
= getRegKind(Reg
, MRI
)) {
116 assert(PrevMask
.none() && NewMask
== MaxMask
);
123 assert(NewMask
< MaxMask
|| NewMask
== MaxMask
);
124 assert(PrevMask
< NewMask
);
126 Value
[Kind
== SGPR_TUPLE
? SGPR32
: Kind
== AGPR_TUPLE
? AGPR32
: VGPR32
] +=
127 Sign
* (~PrevMask
& NewMask
).getNumLanes();
129 if (PrevMask
.none()) {
130 assert(NewMask
.any());
131 Value
[Kind
] += Sign
* MRI
.getPressureSets(Reg
).getWeight();
135 default: llvm_unreachable("Unknown register kind");
139 bool GCNRegPressure::less(const GCNSubtarget
&ST
,
140 const GCNRegPressure
& O
,
141 unsigned MaxOccupancy
) const {
142 const auto SGPROcc
= std::min(MaxOccupancy
,
143 ST
.getOccupancyWithNumSGPRs(getSGPRNum()));
144 const auto VGPROcc
= std::min(MaxOccupancy
,
145 ST
.getOccupancyWithNumVGPRs(getVGPRNum()));
146 const auto OtherSGPROcc
= std::min(MaxOccupancy
,
147 ST
.getOccupancyWithNumSGPRs(O
.getSGPRNum()));
148 const auto OtherVGPROcc
= std::min(MaxOccupancy
,
149 ST
.getOccupancyWithNumVGPRs(O
.getVGPRNum()));
151 const auto Occ
= std::min(SGPROcc
, VGPROcc
);
152 const auto OtherOcc
= std::min(OtherSGPROcc
, OtherVGPROcc
);
154 return Occ
> OtherOcc
;
156 bool SGPRImportant
= SGPROcc
< VGPROcc
;
157 const bool OtherSGPRImportant
= OtherSGPROcc
< OtherVGPROcc
;
159 // if both pressures disagree on what is more important compare vgprs
160 if (SGPRImportant
!= OtherSGPRImportant
) {
161 SGPRImportant
= false;
164 // compare large regs pressure
165 bool SGPRFirst
= SGPRImportant
;
166 for (int I
= 2; I
> 0; --I
, SGPRFirst
= !SGPRFirst
) {
168 auto SW
= getSGPRTuplesWeight();
169 auto OtherSW
= O
.getSGPRTuplesWeight();
173 auto VW
= getVGPRTuplesWeight();
174 auto OtherVW
= O
.getVGPRTuplesWeight();
179 return SGPRImportant
? (getSGPRNum() < O
.getSGPRNum()):
180 (getVGPRNum() < O
.getVGPRNum());
183 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
185 void GCNRegPressure::print(raw_ostream
&OS
, const GCNSubtarget
*ST
) const {
186 OS
<< "VGPRs: " << getVGPRNum();
187 if (ST
) OS
<< "(O" << ST
->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
188 OS
<< ", SGPRs: " << getSGPRNum();
189 if (ST
) OS
<< "(O" << ST
->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
190 OS
<< ", LVGPR WT: " << getVGPRTuplesWeight()
191 << ", LSGPR WT: " << getSGPRTuplesWeight();
192 if (ST
) OS
<< " -> Occ: " << getOccupancy(*ST
);
197 static LaneBitmask
getDefRegMask(const MachineOperand
&MO
,
198 const MachineRegisterInfo
&MRI
) {
199 assert(MO
.isDef() && MO
.isReg() &&
200 TargetRegisterInfo::isVirtualRegister(MO
.getReg()));
202 // We don't rely on read-undef flag because in case of tentative schedule
203 // tracking it isn't set correctly yet. This works correctly however since
204 // use mask has been tracked before using LIS.
205 return MO
.getSubReg() == 0 ?
206 MRI
.getMaxLaneMaskForVReg(MO
.getReg()) :
207 MRI
.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO
.getSubReg());
210 static LaneBitmask
getUsedRegMask(const MachineOperand
&MO
,
211 const MachineRegisterInfo
&MRI
,
212 const LiveIntervals
&LIS
) {
213 assert(MO
.isUse() && MO
.isReg() &&
214 TargetRegisterInfo::isVirtualRegister(MO
.getReg()));
216 if (auto SubReg
= MO
.getSubReg())
217 return MRI
.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg
);
219 auto MaxMask
= MRI
.getMaxLaneMaskForVReg(MO
.getReg());
220 if (MaxMask
== LaneBitmask::getLane(0)) // cannot have subregs
223 // For a tentative schedule LIS isn't updated yet but livemask should remain
224 // the same on any schedule. Subreg defs can be reordered but they all must
225 // dominate uses anyway.
226 auto SI
= LIS
.getInstructionIndex(*MO
.getParent()).getBaseIndex();
227 return getLiveLaneMask(MO
.getReg(), SI
, LIS
, MRI
);
230 static SmallVector
<RegisterMaskPair
, 8>
231 collectVirtualRegUses(const MachineInstr
&MI
, const LiveIntervals
&LIS
,
232 const MachineRegisterInfo
&MRI
) {
233 SmallVector
<RegisterMaskPair
, 8> Res
;
234 for (const auto &MO
: MI
.operands()) {
235 if (!MO
.isReg() || !TargetRegisterInfo::isVirtualRegister(MO
.getReg()))
237 if (!MO
.isUse() || !MO
.readsReg())
240 auto const UsedMask
= getUsedRegMask(MO
, MRI
, LIS
);
242 auto Reg
= MO
.getReg();
243 auto I
= std::find_if(Res
.begin(), Res
.end(), [Reg
](const RegisterMaskPair
&RM
) {
244 return RM
.RegUnit
== Reg
;
247 I
->LaneMask
|= UsedMask
;
249 Res
.push_back(RegisterMaskPair(Reg
, UsedMask
));
254 ///////////////////////////////////////////////////////////////////////////////
257 LaneBitmask
llvm::getLiveLaneMask(unsigned Reg
,
259 const LiveIntervals
&LIS
,
260 const MachineRegisterInfo
&MRI
) {
261 LaneBitmask LiveMask
;
262 const auto &LI
= LIS
.getInterval(Reg
);
263 if (LI
.hasSubRanges()) {
264 for (const auto &S
: LI
.subranges())
266 LiveMask
|= S
.LaneMask
;
267 assert(LiveMask
< MRI
.getMaxLaneMaskForVReg(Reg
) ||
268 LiveMask
== MRI
.getMaxLaneMaskForVReg(Reg
));
270 } else if (LI
.liveAt(SI
)) {
271 LiveMask
= MRI
.getMaxLaneMaskForVReg(Reg
);
276 GCNRPTracker::LiveRegSet
llvm::getLiveRegs(SlotIndex SI
,
277 const LiveIntervals
&LIS
,
278 const MachineRegisterInfo
&MRI
) {
279 GCNRPTracker::LiveRegSet LiveRegs
;
280 for (unsigned I
= 0, E
= MRI
.getNumVirtRegs(); I
!= E
; ++I
) {
281 auto Reg
= TargetRegisterInfo::index2VirtReg(I
);
282 if (!LIS
.hasInterval(Reg
))
284 auto LiveMask
= getLiveLaneMask(Reg
, SI
, LIS
, MRI
);
286 LiveRegs
[Reg
] = LiveMask
;
291 void GCNRPTracker::reset(const MachineInstr
&MI
,
292 const LiveRegSet
*LiveRegsCopy
,
294 const MachineFunction
&MF
= *MI
.getMF();
295 MRI
= &MF
.getRegInfo();
297 if (&LiveRegs
!= LiveRegsCopy
)
298 LiveRegs
= *LiveRegsCopy
;
300 LiveRegs
= After
? getLiveRegsAfter(MI
, LIS
)
301 : getLiveRegsBefore(MI
, LIS
);
304 MaxPressure
= CurPressure
= getRegPressure(*MRI
, LiveRegs
);
307 void GCNUpwardRPTracker::reset(const MachineInstr
&MI
,
308 const LiveRegSet
*LiveRegsCopy
) {
309 GCNRPTracker::reset(MI
, LiveRegsCopy
, true);
312 void GCNUpwardRPTracker::recede(const MachineInstr
&MI
) {
313 assert(MRI
&& "call reset first");
317 if (MI
.isDebugInstr())
320 auto const RegUses
= collectVirtualRegUses(MI
, LIS
, *MRI
);
322 // calc pressure at the MI (defs + uses)
323 auto AtMIPressure
= CurPressure
;
324 for (const auto &U
: RegUses
) {
325 auto LiveMask
= LiveRegs
[U
.RegUnit
];
326 AtMIPressure
.inc(U
.RegUnit
, LiveMask
, LiveMask
| U
.LaneMask
, *MRI
);
328 // update max pressure
329 MaxPressure
= max(AtMIPressure
, MaxPressure
);
331 for (const auto &MO
: MI
.defs()) {
332 if (!MO
.isReg() || !TargetRegisterInfo::isVirtualRegister(MO
.getReg()) ||
336 auto Reg
= MO
.getReg();
337 auto I
= LiveRegs
.find(Reg
);
338 if (I
== LiveRegs
.end())
340 auto &LiveMask
= I
->second
;
341 auto PrevMask
= LiveMask
;
342 LiveMask
&= ~getDefRegMask(MO
, *MRI
);
343 CurPressure
.inc(Reg
, PrevMask
, LiveMask
, *MRI
);
347 for (const auto &U
: RegUses
) {
348 auto &LiveMask
= LiveRegs
[U
.RegUnit
];
349 auto PrevMask
= LiveMask
;
350 LiveMask
|= U
.LaneMask
;
351 CurPressure
.inc(U
.RegUnit
, PrevMask
, LiveMask
, *MRI
);
353 assert(CurPressure
== getRegPressure(*MRI
, LiveRegs
));
356 bool GCNDownwardRPTracker::reset(const MachineInstr
&MI
,
357 const LiveRegSet
*LiveRegsCopy
) {
358 MRI
= &MI
.getParent()->getParent()->getRegInfo();
359 LastTrackedMI
= nullptr;
360 MBBEnd
= MI
.getParent()->end();
362 NextMI
= skipDebugInstructionsForward(NextMI
, MBBEnd
);
363 if (NextMI
== MBBEnd
)
365 GCNRPTracker::reset(*NextMI
, LiveRegsCopy
, false);
369 bool GCNDownwardRPTracker::advanceBeforeNext() {
370 assert(MRI
&& "call reset first");
372 NextMI
= skipDebugInstructionsForward(NextMI
, MBBEnd
);
373 if (NextMI
== MBBEnd
)
376 SlotIndex SI
= LIS
.getInstructionIndex(*NextMI
).getBaseIndex();
377 assert(SI
.isValid());
379 // Remove dead registers or mask bits.
380 for (auto &It
: LiveRegs
) {
381 const LiveInterval
&LI
= LIS
.getInterval(It
.first
);
382 if (LI
.hasSubRanges()) {
383 for (const auto &S
: LI
.subranges()) {
385 auto PrevMask
= It
.second
;
386 It
.second
&= ~S
.LaneMask
;
387 CurPressure
.inc(It
.first
, PrevMask
, It
.second
, *MRI
);
390 } else if (!LI
.liveAt(SI
)) {
391 auto PrevMask
= It
.second
;
392 It
.second
= LaneBitmask::getNone();
393 CurPressure
.inc(It
.first
, PrevMask
, It
.second
, *MRI
);
395 if (It
.second
.none())
396 LiveRegs
.erase(It
.first
);
399 MaxPressure
= max(MaxPressure
, CurPressure
);
404 void GCNDownwardRPTracker::advanceToNext() {
405 LastTrackedMI
= &*NextMI
++;
407 // Add new registers or mask bits.
408 for (const auto &MO
: LastTrackedMI
->defs()) {
411 unsigned Reg
= MO
.getReg();
412 if (!TargetRegisterInfo::isVirtualRegister(Reg
))
414 auto &LiveMask
= LiveRegs
[Reg
];
415 auto PrevMask
= LiveMask
;
416 LiveMask
|= getDefRegMask(MO
, *MRI
);
417 CurPressure
.inc(Reg
, PrevMask
, LiveMask
, *MRI
);
420 MaxPressure
= max(MaxPressure
, CurPressure
);
423 bool GCNDownwardRPTracker::advance() {
424 // If we have just called reset live set is actual.
425 if ((NextMI
== MBBEnd
) || (LastTrackedMI
&& !advanceBeforeNext()))
431 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End
) {
432 while (NextMI
!= End
)
433 if (!advance()) return false;
437 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin
,
438 MachineBasicBlock::const_iterator End
,
439 const LiveRegSet
*LiveRegsCopy
) {
440 reset(*Begin
, LiveRegsCopy
);
444 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
446 static void reportMismatch(const GCNRPTracker::LiveRegSet
&LISLR
,
447 const GCNRPTracker::LiveRegSet
&TrackedLR
,
448 const TargetRegisterInfo
*TRI
) {
449 for (auto const &P
: TrackedLR
) {
450 auto I
= LISLR
.find(P
.first
);
451 if (I
== LISLR
.end()) {
452 dbgs() << " " << printReg(P
.first
, TRI
)
453 << ":L" << PrintLaneMask(P
.second
)
454 << " isn't found in LIS reported set\n";
456 else if (I
->second
!= P
.second
) {
457 dbgs() << " " << printReg(P
.first
, TRI
)
458 << " masks doesn't match: LIS reported "
459 << PrintLaneMask(I
->second
)
461 << PrintLaneMask(P
.second
)
465 for (auto const &P
: LISLR
) {
466 auto I
= TrackedLR
.find(P
.first
);
467 if (I
== TrackedLR
.end()) {
468 dbgs() << " " << printReg(P
.first
, TRI
)
469 << ":L" << PrintLaneMask(P
.second
)
470 << " isn't found in tracked set\n";
475 bool GCNUpwardRPTracker::isValid() const {
476 const auto &SI
= LIS
.getInstructionIndex(*LastTrackedMI
).getBaseIndex();
477 const auto LISLR
= llvm::getLiveRegs(SI
, LIS
, *MRI
);
478 const auto &TrackedLR
= LiveRegs
;
480 if (!isEqual(LISLR
, TrackedLR
)) {
481 dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
482 " LIS reported livesets mismatch:\n";
483 printLivesAt(SI
, LIS
, *MRI
);
484 reportMismatch(LISLR
, TrackedLR
, MRI
->getTargetRegisterInfo());
488 auto LISPressure
= getRegPressure(*MRI
, LISLR
);
489 if (LISPressure
!= CurPressure
) {
490 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
491 CurPressure
.print(dbgs());
492 dbgs() << "LIS rpt: ";
493 LISPressure
.print(dbgs());
499 void GCNRPTracker::printLiveRegs(raw_ostream
&OS
, const LiveRegSet
& LiveRegs
,
500 const MachineRegisterInfo
&MRI
) {
501 const TargetRegisterInfo
*TRI
= MRI
.getTargetRegisterInfo();
502 for (unsigned I
= 0, E
= MRI
.getNumVirtRegs(); I
!= E
; ++I
) {
503 unsigned Reg
= TargetRegisterInfo::index2VirtReg(I
);
504 auto It
= LiveRegs
.find(Reg
);
505 if (It
!= LiveRegs
.end() && It
->second
.any())
506 OS
<< ' ' << printVRegOrUnit(Reg
, TRI
) << ':'
507 << PrintLaneMask(It
->second
);