//===- GCNRegPressure.cpp -------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the GCNRegPressure class.
///
//===----------------------------------------------------------------------===//
14 #include "GCNRegPressure.h"
16 #include "llvm/CodeGen/RegisterPressure.h"
20 #define DEBUG_TYPE "machine-scheduler"
22 bool llvm::isEqual(const GCNRPTracker::LiveRegSet
&S1
,
23 const GCNRPTracker::LiveRegSet
&S2
) {
24 if (S1
.size() != S2
.size())
27 for (const auto &P
: S1
) {
28 auto I
= S2
.find(P
.first
);
29 if (I
== S2
.end() || I
->second
!= P
.second
)
///////////////////////////////////////////////////////////////////////////////
// GCNRegPressure

38 unsigned GCNRegPressure::getRegKind(Register Reg
,
39 const MachineRegisterInfo
&MRI
) {
40 assert(Reg
.isVirtual());
41 const auto RC
= MRI
.getRegClass(Reg
);
42 auto STI
= static_cast<const SIRegisterInfo
*>(MRI
.getTargetRegisterInfo());
43 return STI
->isSGPRClass(RC
)
44 ? (STI
->getRegSizeInBits(*RC
) == 32 ? SGPR32
: SGPR_TUPLE
)
45 : STI
->isAGPRClass(RC
)
46 ? (STI
->getRegSizeInBits(*RC
) == 32 ? AGPR32
: AGPR_TUPLE
)
47 : (STI
->getRegSizeInBits(*RC
) == 32 ? VGPR32
: VGPR_TUPLE
);
50 void GCNRegPressure::inc(unsigned Reg
,
53 const MachineRegisterInfo
&MRI
) {
54 if (SIRegisterInfo::getNumCoveredRegs(NewMask
) ==
55 SIRegisterInfo::getNumCoveredRegs(PrevMask
))
59 if (NewMask
< PrevMask
) {
60 std::swap(NewMask
, PrevMask
);
64 switch (auto Kind
= getRegKind(Reg
, MRI
)) {
74 assert(PrevMask
< NewMask
);
76 Value
[Kind
== SGPR_TUPLE
? SGPR32
: Kind
== AGPR_TUPLE
? AGPR32
: VGPR32
] +=
77 Sign
* SIRegisterInfo::getNumCoveredRegs(~PrevMask
& NewMask
);
79 if (PrevMask
.none()) {
80 assert(NewMask
.any());
81 const TargetRegisterInfo
*TRI
= MRI
.getTargetRegisterInfo();
83 Sign
* TRI
->getRegClassWeight(MRI
.getRegClass(Reg
)).RegWeight
;
87 default: llvm_unreachable("Unknown register kind");
91 bool GCNRegPressure::less(const MachineFunction
&MF
, const GCNRegPressure
&O
,
92 unsigned MaxOccupancy
) const {
93 const GCNSubtarget
&ST
= MF
.getSubtarget
<GCNSubtarget
>();
95 const auto SGPROcc
= std::min(MaxOccupancy
,
96 ST
.getOccupancyWithNumSGPRs(getSGPRNum()));
98 std::min(MaxOccupancy
,
99 ST
.getOccupancyWithNumVGPRs(getVGPRNum(ST
.hasGFX90AInsts())));
100 const auto OtherSGPROcc
= std::min(MaxOccupancy
,
101 ST
.getOccupancyWithNumSGPRs(O
.getSGPRNum()));
102 const auto OtherVGPROcc
=
103 std::min(MaxOccupancy
,
104 ST
.getOccupancyWithNumVGPRs(O
.getVGPRNum(ST
.hasGFX90AInsts())));
106 const auto Occ
= std::min(SGPROcc
, VGPROcc
);
107 const auto OtherOcc
= std::min(OtherSGPROcc
, OtherVGPROcc
);
109 // Give first precedence to the better occupancy.
111 return Occ
> OtherOcc
;
113 unsigned MaxVGPRs
= ST
.getMaxNumVGPRs(MF
);
114 unsigned MaxSGPRs
= ST
.getMaxNumSGPRs(MF
);
116 // SGPR excess pressure conditions
117 unsigned ExcessSGPR
= std::max(static_cast<int>(getSGPRNum() - MaxSGPRs
), 0);
118 unsigned OtherExcessSGPR
=
119 std::max(static_cast<int>(O
.getSGPRNum() - MaxSGPRs
), 0);
121 auto WaveSize
= ST
.getWavefrontSize();
122 // The number of virtual VGPRs required to handle excess SGPR
123 unsigned VGPRForSGPRSpills
= (ExcessSGPR
+ (WaveSize
- 1)) / WaveSize
;
124 unsigned OtherVGPRForSGPRSpills
=
125 (OtherExcessSGPR
+ (WaveSize
- 1)) / WaveSize
;
127 unsigned MaxArchVGPRs
= ST
.getAddressableNumArchVGPRs();
129 // Unified excess pressure conditions, accounting for VGPRs used for SGPR
131 unsigned ExcessVGPR
=
132 std::max(static_cast<int>(getVGPRNum(ST
.hasGFX90AInsts()) +
133 VGPRForSGPRSpills
- MaxVGPRs
),
135 unsigned OtherExcessVGPR
=
136 std::max(static_cast<int>(O
.getVGPRNum(ST
.hasGFX90AInsts()) +
137 OtherVGPRForSGPRSpills
- MaxVGPRs
),
139 // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR
141 unsigned ExcessArchVGPR
= std::max(
142 static_cast<int>(getVGPRNum(false) + VGPRForSGPRSpills
- MaxArchVGPRs
),
144 unsigned OtherExcessArchVGPR
=
145 std::max(static_cast<int>(O
.getVGPRNum(false) + OtherVGPRForSGPRSpills
-
148 // AGPR excess pressure conditions
149 unsigned ExcessAGPR
= std::max(
150 static_cast<int>(ST
.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs
)
151 : (getAGPRNum() - MaxVGPRs
)),
153 unsigned OtherExcessAGPR
= std::max(
154 static_cast<int>(ST
.hasGFX90AInsts() ? (O
.getAGPRNum() - MaxArchVGPRs
)
155 : (O
.getAGPRNum() - MaxVGPRs
)),
158 bool ExcessRP
= ExcessSGPR
|| ExcessVGPR
|| ExcessArchVGPR
|| ExcessAGPR
;
159 bool OtherExcessRP
= OtherExcessSGPR
|| OtherExcessVGPR
||
160 OtherExcessArchVGPR
|| OtherExcessAGPR
;
162 // Give second precedence to the reduced number of spills to hold the register
164 if (ExcessRP
|| OtherExcessRP
) {
165 // The difference in excess VGPR pressure, after including VGPRs used for
167 int VGPRDiff
= ((OtherExcessVGPR
+ OtherExcessArchVGPR
+ OtherExcessAGPR
) -
168 (ExcessVGPR
+ ExcessArchVGPR
+ ExcessAGPR
));
170 int SGPRDiff
= OtherExcessSGPR
- ExcessSGPR
;
175 unsigned PureExcessVGPR
=
176 std::max(static_cast<int>(getVGPRNum(ST
.hasGFX90AInsts()) - MaxVGPRs
),
178 std::max(static_cast<int>(getVGPRNum(false) - MaxArchVGPRs
), 0);
179 unsigned OtherPureExcessVGPR
=
181 static_cast<int>(O
.getVGPRNum(ST
.hasGFX90AInsts()) - MaxVGPRs
),
183 std::max(static_cast<int>(O
.getVGPRNum(false) - MaxArchVGPRs
), 0);
185 // If we have a special case where there is a tie in excess VGPR, but one
186 // of the pressures has VGPR usage from SGPR spills, prefer the pressure
188 if (PureExcessVGPR
!= OtherPureExcessVGPR
)
190 // If both pressures have the same excess pressure before and after
191 // accounting for SGPR spills, prefer fewer SGPR spills.
196 bool SGPRImportant
= SGPROcc
< VGPROcc
;
197 const bool OtherSGPRImportant
= OtherSGPROcc
< OtherVGPROcc
;
199 // If both pressures disagree on what is more important compare vgprs.
200 if (SGPRImportant
!= OtherSGPRImportant
) {
201 SGPRImportant
= false;
204 // Give third precedence to lower register tuple pressure.
205 bool SGPRFirst
= SGPRImportant
;
206 for (int I
= 2; I
> 0; --I
, SGPRFirst
= !SGPRFirst
) {
208 auto SW
= getSGPRTuplesWeight();
209 auto OtherSW
= O
.getSGPRTuplesWeight();
213 auto VW
= getVGPRTuplesWeight();
214 auto OtherVW
= O
.getVGPRTuplesWeight();
220 // Give final precedence to lower general RP.
221 return SGPRImportant
? (getSGPRNum() < O
.getSGPRNum()):
222 (getVGPRNum(ST
.hasGFX90AInsts()) <
223 O
.getVGPRNum(ST
.hasGFX90AInsts()));
226 Printable
llvm::print(const GCNRegPressure
&RP
, const GCNSubtarget
*ST
) {
227 return Printable([&RP
, ST
](raw_ostream
&OS
) {
228 OS
<< "VGPRs: " << RP
.Value
[GCNRegPressure::VGPR32
] << ' '
229 << "AGPRs: " << RP
.getAGPRNum();
232 << ST
->getOccupancyWithNumVGPRs(RP
.getVGPRNum(ST
->hasGFX90AInsts()))
234 OS
<< ", SGPRs: " << RP
.getSGPRNum();
236 OS
<< "(O" << ST
->getOccupancyWithNumSGPRs(RP
.getSGPRNum()) << ')';
237 OS
<< ", LVGPR WT: " << RP
.getVGPRTuplesWeight()
238 << ", LSGPR WT: " << RP
.getSGPRTuplesWeight();
240 OS
<< " -> Occ: " << RP
.getOccupancy(*ST
);
245 static LaneBitmask
getDefRegMask(const MachineOperand
&MO
,
246 const MachineRegisterInfo
&MRI
) {
247 assert(MO
.isDef() && MO
.isReg() && MO
.getReg().isVirtual());
249 // We don't rely on read-undef flag because in case of tentative schedule
250 // tracking it isn't set correctly yet. This works correctly however since
251 // use mask has been tracked before using LIS.
252 return MO
.getSubReg() == 0 ?
253 MRI
.getMaxLaneMaskForVReg(MO
.getReg()) :
254 MRI
.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO
.getSubReg());
258 collectVirtualRegUses(SmallVectorImpl
<RegisterMaskPair
> &RegMaskPairs
,
259 const MachineInstr
&MI
, const LiveIntervals
&LIS
,
260 const MachineRegisterInfo
&MRI
) {
262 for (const auto &MO
: MI
.operands()) {
263 if (!MO
.isReg() || !MO
.getReg().isVirtual())
265 if (!MO
.isUse() || !MO
.readsReg())
268 Register Reg
= MO
.getReg();
269 if (llvm::any_of(RegMaskPairs
, [Reg
](const RegisterMaskPair
&RM
) {
270 return RM
.RegUnit
== Reg
;
275 auto &LI
= LIS
.getInterval(Reg
);
276 if (!LI
.hasSubRanges())
277 UseMask
= MRI
.getMaxLaneMaskForVReg(Reg
);
279 // For a tentative schedule LIS isn't updated yet but livemask should
280 // remain the same on any schedule. Subreg defs can be reordered but they
281 // all must dominate uses anyway.
283 InstrSI
= LIS
.getInstructionIndex(*MO
.getParent()).getBaseIndex();
284 UseMask
= getLiveLaneMask(LI
, InstrSI
, MRI
);
287 RegMaskPairs
.emplace_back(Reg
, UseMask
);
///////////////////////////////////////////////////////////////////////////////
// GCNRPTracker

294 LaneBitmask
llvm::getLiveLaneMask(unsigned Reg
, SlotIndex SI
,
295 const LiveIntervals
&LIS
,
296 const MachineRegisterInfo
&MRI
) {
297 return getLiveLaneMask(LIS
.getInterval(Reg
), SI
, MRI
);
300 LaneBitmask
llvm::getLiveLaneMask(const LiveInterval
&LI
, SlotIndex SI
,
301 const MachineRegisterInfo
&MRI
) {
302 LaneBitmask LiveMask
;
303 if (LI
.hasSubRanges()) {
304 for (const auto &S
: LI
.subranges())
306 LiveMask
|= S
.LaneMask
;
307 assert(LiveMask
== (LiveMask
& MRI
.getMaxLaneMaskForVReg(LI
.reg())));
309 } else if (LI
.liveAt(SI
)) {
310 LiveMask
= MRI
.getMaxLaneMaskForVReg(LI
.reg());
315 GCNRPTracker::LiveRegSet
llvm::getLiveRegs(SlotIndex SI
,
316 const LiveIntervals
&LIS
,
317 const MachineRegisterInfo
&MRI
) {
318 GCNRPTracker::LiveRegSet LiveRegs
;
319 for (unsigned I
= 0, E
= MRI
.getNumVirtRegs(); I
!= E
; ++I
) {
320 auto Reg
= Register::index2VirtReg(I
);
321 if (!LIS
.hasInterval(Reg
))
323 auto LiveMask
= getLiveLaneMask(Reg
, SI
, LIS
, MRI
);
325 LiveRegs
[Reg
] = LiveMask
;
330 void GCNRPTracker::reset(const MachineInstr
&MI
,
331 const LiveRegSet
*LiveRegsCopy
,
333 const MachineFunction
&MF
= *MI
.getMF();
334 MRI
= &MF
.getRegInfo();
336 if (&LiveRegs
!= LiveRegsCopy
)
337 LiveRegs
= *LiveRegsCopy
;
339 LiveRegs
= After
? getLiveRegsAfter(MI
, LIS
)
340 : getLiveRegsBefore(MI
, LIS
);
343 MaxPressure
= CurPressure
= getRegPressure(*MRI
, LiveRegs
);
////////////////////////////////////////////////////////////////////////////////
// GCNUpwardRPTracker

349 void GCNUpwardRPTracker::reset(const MachineRegisterInfo
&MRI_
,
350 const LiveRegSet
&LiveRegs_
) {
352 LiveRegs
= LiveRegs_
;
353 LastTrackedMI
= nullptr;
354 MaxPressure
= CurPressure
= getRegPressure(MRI_
, LiveRegs_
);
357 void GCNUpwardRPTracker::recede(const MachineInstr
&MI
) {
358 assert(MRI
&& "call reset first");
362 if (MI
.isDebugInstr())
366 GCNRegPressure DefPressure
, ECDefPressure
;
367 bool HasECDefs
= false;
368 for (const MachineOperand
&MO
: MI
.all_defs()) {
369 if (!MO
.getReg().isVirtual())
372 Register Reg
= MO
.getReg();
373 LaneBitmask DefMask
= getDefRegMask(MO
, *MRI
);
375 // Treat a def as fully live at the moment of definition: keep a record.
376 if (MO
.isEarlyClobber()) {
377 ECDefPressure
.inc(Reg
, LaneBitmask::getNone(), DefMask
, *MRI
);
380 DefPressure
.inc(Reg
, LaneBitmask::getNone(), DefMask
, *MRI
);
382 auto I
= LiveRegs
.find(Reg
);
383 if (I
== LiveRegs
.end())
386 LaneBitmask
&LiveMask
= I
->second
;
387 LaneBitmask PrevMask
= LiveMask
;
388 LiveMask
&= ~DefMask
;
389 CurPressure
.inc(Reg
, PrevMask
, LiveMask
, *MRI
);
394 // Update MaxPressure with defs pressure.
395 DefPressure
+= CurPressure
;
397 DefPressure
+= ECDefPressure
;
398 MaxPressure
= max(DefPressure
, MaxPressure
);
401 SmallVector
<RegisterMaskPair
, 8> RegUses
;
402 collectVirtualRegUses(RegUses
, MI
, LIS
, *MRI
);
403 for (const RegisterMaskPair
&U
: RegUses
) {
404 LaneBitmask
&LiveMask
= LiveRegs
[U
.RegUnit
];
405 LaneBitmask PrevMask
= LiveMask
;
406 LiveMask
|= U
.LaneMask
;
407 CurPressure
.inc(U
.RegUnit
, PrevMask
, LiveMask
, *MRI
);
410 // Update MaxPressure with uses plus early-clobber defs pressure.
411 MaxPressure
= HasECDefs
? max(CurPressure
+ ECDefPressure
, MaxPressure
)
412 : max(CurPressure
, MaxPressure
);
414 assert(CurPressure
== getRegPressure(*MRI
, LiveRegs
));
////////////////////////////////////////////////////////////////////////////////
// GCNDownwardRPTracker

420 bool GCNDownwardRPTracker::reset(const MachineInstr
&MI
,
421 const LiveRegSet
*LiveRegsCopy
) {
422 MRI
= &MI
.getParent()->getParent()->getRegInfo();
423 LastTrackedMI
= nullptr;
424 MBBEnd
= MI
.getParent()->end();
426 NextMI
= skipDebugInstructionsForward(NextMI
, MBBEnd
);
427 if (NextMI
== MBBEnd
)
429 GCNRPTracker::reset(*NextMI
, LiveRegsCopy
, false);
433 bool GCNDownwardRPTracker::advanceBeforeNext() {
434 assert(MRI
&& "call reset first");
436 return NextMI
== MBBEnd
;
438 assert(NextMI
== MBBEnd
|| !NextMI
->isDebugInstr());
440 SlotIndex SI
= NextMI
== MBBEnd
441 ? LIS
.getInstructionIndex(*LastTrackedMI
).getDeadSlot()
442 : LIS
.getInstructionIndex(*NextMI
).getBaseIndex();
443 assert(SI
.isValid());
445 // Remove dead registers or mask bits.
446 SmallSet
<Register
, 8> SeenRegs
;
447 for (auto &MO
: LastTrackedMI
->operands()) {
448 if (!MO
.isReg() || !MO
.getReg().isVirtual())
450 if (MO
.isUse() && !MO
.readsReg())
452 if (!SeenRegs
.insert(MO
.getReg()).second
)
454 const LiveInterval
&LI
= LIS
.getInterval(MO
.getReg());
455 if (LI
.hasSubRanges()) {
456 auto It
= LiveRegs
.end();
457 for (const auto &S
: LI
.subranges()) {
459 if (It
== LiveRegs
.end()) {
460 It
= LiveRegs
.find(MO
.getReg());
461 if (It
== LiveRegs
.end())
462 llvm_unreachable("register isn't live");
464 auto PrevMask
= It
->second
;
465 It
->second
&= ~S
.LaneMask
;
466 CurPressure
.inc(MO
.getReg(), PrevMask
, It
->second
, *MRI
);
469 if (It
!= LiveRegs
.end() && It
->second
.none())
471 } else if (!LI
.liveAt(SI
)) {
472 auto It
= LiveRegs
.find(MO
.getReg());
473 if (It
== LiveRegs
.end())
474 llvm_unreachable("register isn't live");
475 CurPressure
.inc(MO
.getReg(), It
->second
, LaneBitmask::getNone(), *MRI
);
480 MaxPressure
= max(MaxPressure
, CurPressure
);
482 LastTrackedMI
= nullptr;
484 return NextMI
== MBBEnd
;
487 void GCNDownwardRPTracker::advanceToNext() {
488 LastTrackedMI
= &*NextMI
++;
489 NextMI
= skipDebugInstructionsForward(NextMI
, MBBEnd
);
491 // Add new registers or mask bits.
492 for (const auto &MO
: LastTrackedMI
->all_defs()) {
493 Register Reg
= MO
.getReg();
494 if (!Reg
.isVirtual())
496 auto &LiveMask
= LiveRegs
[Reg
];
497 auto PrevMask
= LiveMask
;
498 LiveMask
|= getDefRegMask(MO
, *MRI
);
499 CurPressure
.inc(Reg
, PrevMask
, LiveMask
, *MRI
);
502 MaxPressure
= max(MaxPressure
, CurPressure
);
505 bool GCNDownwardRPTracker::advance() {
506 if (NextMI
== MBBEnd
)
513 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End
) {
514 while (NextMI
!= End
)
515 if (!advance()) return false;
519 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin
,
520 MachineBasicBlock::const_iterator End
,
521 const LiveRegSet
*LiveRegsCopy
) {
522 reset(*Begin
, LiveRegsCopy
);
526 Printable
llvm::reportMismatch(const GCNRPTracker::LiveRegSet
&LISLR
,
527 const GCNRPTracker::LiveRegSet
&TrackedLR
,
528 const TargetRegisterInfo
*TRI
, StringRef Pfx
) {
529 return Printable([&LISLR
, &TrackedLR
, TRI
, Pfx
](raw_ostream
&OS
) {
530 for (auto const &P
: TrackedLR
) {
531 auto I
= LISLR
.find(P
.first
);
532 if (I
== LISLR
.end()) {
533 OS
<< Pfx
<< printReg(P
.first
, TRI
) << ":L" << PrintLaneMask(P
.second
)
534 << " isn't found in LIS reported set\n";
535 } else if (I
->second
!= P
.second
) {
536 OS
<< Pfx
<< printReg(P
.first
, TRI
)
537 << " masks doesn't match: LIS reported " << PrintLaneMask(I
->second
)
538 << ", tracked " << PrintLaneMask(P
.second
) << '\n';
541 for (auto const &P
: LISLR
) {
542 auto I
= TrackedLR
.find(P
.first
);
543 if (I
== TrackedLR
.end()) {
544 OS
<< Pfx
<< printReg(P
.first
, TRI
) << ":L" << PrintLaneMask(P
.second
)
545 << " isn't found in tracked set\n";
551 bool GCNUpwardRPTracker::isValid() const {
552 const auto &SI
= LIS
.getInstructionIndex(*LastTrackedMI
).getBaseIndex();
553 const auto LISLR
= llvm::getLiveRegs(SI
, LIS
, *MRI
);
554 const auto &TrackedLR
= LiveRegs
;
556 if (!isEqual(LISLR
, TrackedLR
)) {
557 dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
558 " LIS reported livesets mismatch:\n"
559 << print(LISLR
, *MRI
);
560 reportMismatch(LISLR
, TrackedLR
, MRI
->getTargetRegisterInfo());
564 auto LISPressure
= getRegPressure(*MRI
, LISLR
);
565 if (LISPressure
!= CurPressure
) {
566 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "
567 << print(CurPressure
) << "LIS rpt: " << print(LISPressure
);
573 Printable
llvm::print(const GCNRPTracker::LiveRegSet
&LiveRegs
,
574 const MachineRegisterInfo
&MRI
) {
575 return Printable([&LiveRegs
, &MRI
](raw_ostream
&OS
) {
576 const TargetRegisterInfo
*TRI
= MRI
.getTargetRegisterInfo();
577 for (unsigned I
= 0, E
= MRI
.getNumVirtRegs(); I
!= E
; ++I
) {
578 Register Reg
= Register::index2VirtReg(I
);
579 auto It
= LiveRegs
.find(Reg
);
580 if (It
!= LiveRegs
.end() && It
->second
.any())
581 OS
<< ' ' << printVRegOrUnit(Reg
, TRI
) << ':'
582 << PrintLaneMask(It
->second
);
588 void GCNRegPressure::dump() const { dbgs() << print(*this); }
590 static cl::opt
<bool> UseDownwardTracker(
591 "amdgpu-print-rp-downward",
592 cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"),
593 cl::init(false), cl::Hidden
);
595 char llvm::GCNRegPressurePrinter::ID
= 0;
596 char &llvm::GCNRegPressurePrinterID
= GCNRegPressurePrinter::ID
;
598 INITIALIZE_PASS(GCNRegPressurePrinter
, "amdgpu-print-rp", "", true, true)
600 // Return lanemask of Reg's subregs that are live-through at [Begin, End] and
601 // are fully covered by Mask.
603 getRegLiveThroughMask(const MachineRegisterInfo
&MRI
, const LiveIntervals
&LIS
,
604 Register Reg
, SlotIndex Begin
, SlotIndex End
,
605 LaneBitmask Mask
= LaneBitmask::getAll()) {
607 auto IsInOneSegment
= [Begin
, End
](const LiveRange
&LR
) -> bool {
608 auto *Segment
= LR
.getSegmentContaining(Begin
);
609 return Segment
&& Segment
->contains(End
);
612 LaneBitmask LiveThroughMask
;
613 const LiveInterval
&LI
= LIS
.getInterval(Reg
);
614 if (LI
.hasSubRanges()) {
615 for (auto &SR
: LI
.subranges()) {
616 if ((SR
.LaneMask
& Mask
) == SR
.LaneMask
&& IsInOneSegment(SR
))
617 LiveThroughMask
|= SR
.LaneMask
;
620 LaneBitmask RegMask
= MRI
.getMaxLaneMaskForVReg(Reg
);
621 if ((RegMask
& Mask
) == RegMask
&& IsInOneSegment(LI
))
622 LiveThroughMask
= RegMask
;
625 return LiveThroughMask
;
628 bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction
&MF
) {
629 const MachineRegisterInfo
&MRI
= MF
.getRegInfo();
630 const TargetRegisterInfo
*TRI
= MRI
.getTargetRegisterInfo();
631 const LiveIntervals
&LIS
= getAnalysis
<LiveIntervalsWrapperPass
>().getLIS();
635 // Leading spaces are important for YAML syntax.
638 OS
<< "---\nname: " << MF
.getName() << "\nbody: |\n";
640 auto printRP
= [](const GCNRegPressure
&RP
) {
641 return Printable([&RP
](raw_ostream
&OS
) {
642 OS
<< format(PFX
" %-5d", RP
.getSGPRNum())
643 << format(" %-5d", RP
.getVGPRNum(false));
647 auto ReportLISMismatchIfAny
= [&](const GCNRPTracker::LiveRegSet
&TrackedLR
,
648 const GCNRPTracker::LiveRegSet
&LISLR
) {
649 if (LISLR
!= TrackedLR
) {
650 OS
<< PFX
" mis LIS: " << llvm::print(LISLR
, MRI
)
651 << reportMismatch(LISLR
, TrackedLR
, TRI
, PFX
" ");
655 // Register pressure before and at an instruction (in program order).
656 SmallVector
<std::pair
<GCNRegPressure
, GCNRegPressure
>, 16> RP
;
658 for (auto &MBB
: MF
) {
660 RP
.reserve(MBB
.size());
666 SlotIndex MBBStartSlot
= LIS
.getSlotIndexes()->getMBBStartIdx(&MBB
);
667 SlotIndex MBBEndSlot
= LIS
.getSlotIndexes()->getMBBEndIdx(&MBB
);
669 GCNRPTracker::LiveRegSet LiveIn
, LiveOut
;
670 GCNRegPressure RPAtMBBEnd
;
672 if (UseDownwardTracker
) {
674 LiveIn
= LiveOut
= getLiveRegs(MBBStartSlot
, LIS
, MRI
);
675 RPAtMBBEnd
= getRegPressure(MRI
, LiveIn
);
677 GCNDownwardRPTracker
RPT(LIS
);
678 RPT
.reset(MBB
.front());
680 LiveIn
= RPT
.getLiveRegs();
682 while (!RPT
.advanceBeforeNext()) {
683 GCNRegPressure RPBeforeMI
= RPT
.getPressure();
685 RP
.emplace_back(RPBeforeMI
, RPT
.getPressure());
688 LiveOut
= RPT
.getLiveRegs();
689 RPAtMBBEnd
= RPT
.getPressure();
692 GCNUpwardRPTracker
RPT(LIS
);
693 RPT
.reset(MRI
, MBBEndSlot
);
695 LiveOut
= RPT
.getLiveRegs();
696 RPAtMBBEnd
= RPT
.getPressure();
698 for (auto &MI
: reverse(MBB
)) {
699 RPT
.resetMaxPressure();
701 if (!MI
.isDebugInstr())
702 RP
.emplace_back(RPT
.getPressure(), RPT
.getMaxPressure());
705 LiveIn
= RPT
.getLiveRegs();
708 OS
<< PFX
" Live-in: " << llvm::print(LiveIn
, MRI
);
709 if (!UseDownwardTracker
)
710 ReportLISMismatchIfAny(LiveIn
, getLiveRegs(MBBStartSlot
, LIS
, MRI
));
712 OS
<< PFX
" SGPR VGPR\n";
714 for (auto &MI
: MBB
) {
715 if (!MI
.isDebugInstr()) {
716 auto &[RPBeforeInstr
, RPAtInstr
] =
717 RP
[UseDownwardTracker
? I
: (RP
.size() - 1 - I
)];
719 OS
<< printRP(RPBeforeInstr
) << '\n' << printRP(RPAtInstr
) << " ";
724 OS
<< printRP(RPAtMBBEnd
) << '\n';
726 OS
<< PFX
" Live-out:" << llvm::print(LiveOut
, MRI
);
727 if (UseDownwardTracker
)
728 ReportLISMismatchIfAny(LiveOut
, getLiveRegs(MBBEndSlot
, LIS
, MRI
));
730 GCNRPTracker::LiveRegSet LiveThrough
;
731 for (auto [Reg
, Mask
] : LiveIn
) {
732 LaneBitmask MaskIntersection
= Mask
& LiveOut
.lookup(Reg
);
733 if (MaskIntersection
.any()) {
734 LaneBitmask LTMask
= getRegLiveThroughMask(
735 MRI
, LIS
, Reg
, MBBStartSlot
, MBBEndSlot
, MaskIntersection
);
737 LiveThrough
[Reg
] = LTMask
;
740 OS
<< PFX
" Live-thr:" << llvm::print(LiveThrough
, MRI
);
741 OS
<< printRP(getRegPressure(MRI
, LiveThrough
)) << '\n';