1 //===-- GCNSchedStrategy.cpp - GCN Scheduler Strategy ---------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 /// This contains a MachineSchedStrategy implementation for maximizing wave
11 /// occupancy on GCN hardware.
12 //===----------------------------------------------------------------------===//
14 #include "GCNSchedStrategy.h"
15 #include "AMDGPUSubtarget.h"
16 #include "SIInstrInfo.h"
17 #include "SIMachineFunctionInfo.h"
18 #include "SIRegisterInfo.h"
19 #include "llvm/CodeGen/RegisterClassInfo.h"
20 #include "llvm/Support/MathExtras.h"
22 #define DEBUG_TYPE "machine-scheduler"
26 GCNMaxOccupancySchedStrategy::GCNMaxOccupancySchedStrategy(
27 const MachineSchedContext
*C
) :
28 GenericScheduler(C
), TargetOccupancy(0), MF(nullptr) { }
// Prepares the strategy for scheduling: runs GenericScheduler's own
// initialization, then computes the SGPR/VGPR "excess" and "critical"
// pressure limits. initCandidate() and GCNScheduleDAGMILive::schedule()
// compare register pressure against these limits.
//   DAG: the scheduling DAG being initialized; DAG->MF is used to query
//        register pressure-set limits.
30 void GCNMaxOccupancySchedStrategy::initialize(ScheduleDAGMI
*DAG
) {
31 GenericScheduler::initialize(DAG
);
33 const SIRegisterInfo
*SRI
= static_cast<const SIRegisterInfo
*>(TRI
);
// NOTE(review): MF is dereferenced below, but the constructor sets it to
// nullptr and no assignment to MF is visible in this extract. Confirm that
// MF is set (e.g. from DAG->MF) before this use.
37 const GCNSubtarget
&ST
= MF
->getSubtarget
<GCNSubtarget
>();
39 // FIXME: This is also necessary, because some passes that run after
40 // scheduling and before regalloc increase register pressure.
41 const int ErrorMargin
= 3;
// Excess limits: number of allocatable 32-bit SGPRs / VGPRs, minus the
// safety margin above.
43 SGPRExcessLimit
= Context
->RegClassInfo
44 ->getNumAllocatableRegs(&AMDGPU::SGPR_32RegClass
) - ErrorMargin
;
45 VGPRExcessLimit
= Context
->RegClassInfo
46 ->getNumAllocatableRegs(&AMDGPU::VGPR_32RegClass
) - ErrorMargin
;
// Critical limits: with an explicit TargetOccupancy, use the maximum SGPR/VGPR
// counts the subtarget allows at that occupancy.
47 if (TargetOccupancy
) {
48 SGPRCriticalLimit
= ST
.getMaxNumSGPRs(TargetOccupancy
, true);
49 VGPRCriticalLimit
= ST
.getMaxNumVGPRs(TargetOccupancy
);
// Presumably the no-target (else) path; the `else` line itself is not visible
// in this extract. Uses the SGPR/VGPR pressure-set limits from SIRegisterInfo.
51 SGPRCriticalLimit
= SRI
->getRegPressureSetLimit(DAG
->MF
,
52 SRI
->getSGPRPressureSet());
53 VGPRCriticalLimit
= SRI
->getRegPressureSetLimit(DAG
->MF
,
54 SRI
->getVGPRPressureSet());
// The critical limits keep the same safety margin as the excess limits.
57 SGPRCriticalLimit
-= ErrorMargin
;
58 VGPRCriticalLimit
-= ErrorMargin
;
// Fills in the register-pressure deltas of scheduling candidate Cand for SU.
// The pressure after SU is queried from RPTracker. Cand.RPDelta.Excess is set
// for at most one of the VGPR/SGPR pressure sets once the new pressure reaches
// that set's excess limit. Cand.RPDelta.CriticalMax is set once either set
// reaches its critical limit.
//   AtTop:        scheduling direction of the candidate.
//   RPTracker:    pressure tracker of the zone (temporarily modified, see below).
//   SRI:          used to look up the SGPR/VGPR pressure-set indices.
//   SGPRPressure,
//   VGPRPressure: current pressure of the SGPR/VGPR sets before SU; they decide
//                 which set is tracked for excess pressure.
// NOTE(review): this extract shows both getDownwardPressure() and
// getUpwardPressure() with no selecting condition, and no assignment of
// Cand.SU / Cand.AtTop. Lines appear to be missing; confirm against the full
// file that the direction is chosen by AtTop.
61 void GCNMaxOccupancySchedStrategy::initCandidate(SchedCandidate
&Cand
, SUnit
*SU
,
62 bool AtTop
, const RegPressureTracker
&RPTracker
,
63 const SIRegisterInfo
*SRI
,
64 unsigned SGPRPressure
,
65 unsigned VGPRPressure
) {
70 // getDownwardPressure() and getUpwardPressure() make temporary changes to
71 // the tracker, so they need a non-const reference to it. Note that this is
// the caller's tracker with const cast away, not a copy.
72 RegPressureTracker
&TempTracker
= const_cast<RegPressureTracker
&>(RPTracker
);
74 std::vector
<unsigned> Pressure
;
75 std::vector
<unsigned> MaxPressure
;
78 TempTracker
.getDownwardPressure(SU
->getInstr(), Pressure
, MaxPressure
);
80 // FIXME: I think for bottom up scheduling, the register pressure is cached
81 // and can be retrieved by DAG->getPressureDiff(SU).
82 TempTracker
.getUpwardPressure(SU
->getInstr(), Pressure
, MaxPressure
);
85 unsigned NewSGPRPressure
= Pressure
[SRI
->getSGPRPressureSet()];
86 unsigned NewVGPRPressure
= Pressure
[SRI
->getVGPRPressureSet()];
88 // If two instructions increase the pressure of different register sets
89 // by the same amount, the generic scheduler will prefer to schedule the
90 // instruction that increases the set with the fewest registers,
91 // which in our case would be SGPRs. This is rarely what we want, so
92 // when we report excess/critical register pressure, we do it either
93 // only for VGPRs or only for SGPRs.
95 // FIXME: Better heuristics to determine whether to prefer SGPRs or VGPRs.
// VGPR excess is tracked as soon as the current VGPR pressure comes within
// MaxVGPRPressureInc of the excess limit. SGPR excess is tracked only when
// VGPRs are not, and only once SGPR pressure has already reached its limit.
96 const unsigned MaxVGPRPressureInc
= 16;
97 bool ShouldTrackVGPRs
= VGPRPressure
+ MaxVGPRPressureInc
>= VGPRExcessLimit
;
98 bool ShouldTrackSGPRs
= !ShouldTrackVGPRs
&& SGPRPressure
>= SGPRExcessLimit
;
101 // FIXME: We have to enter REG-EXCESS before we reach the actual threshold
102 // to increase the likelihood we don't go over the limits. We should improve
103 // the analysis to look through dependencies to find the path with the least
104 // register pressure.
106 // We only need to update the RPDelta for instructions that increase
107 // register pressure. Instructions that decrease or keep reg pressure
108 // the same will be marked as RegExcess in tryCandidate() when they
109 // are compared with instructions that increase the register pressure.
110 if (ShouldTrackVGPRs
&& NewVGPRPressure
>= VGPRExcessLimit
) {
111 Cand
.RPDelta
.Excess
= PressureChange(SRI
->getVGPRPressureSet());
112 Cand
.RPDelta
.Excess
.setUnitInc(NewVGPRPressure
- VGPRExcessLimit
);
115 if (ShouldTrackSGPRs
&& NewSGPRPressure
>= SGPRExcessLimit
) {
116 Cand
.RPDelta
.Excess
= PressureChange(SRI
->getSGPRPressureSet());
117 Cand
.RPDelta
.Excess
.setUnitInc(NewSGPRPressure
- SGPRExcessLimit
);
120 // Register pressure is considered 'CRITICAL' if it is approaching a value
121 // that would reduce the wave occupancy for the execution unit. When
122 // register pressure is 'CRITICAL', increasing SGPR and VGPR pressure both
123 // have the same cost, so we don't need to prefer one over the other.
// A delta >= 0 means the new pressure is at or above the critical limit.
125 int SGPRDelta
= NewSGPRPressure
- SGPRCriticalLimit
;
126 int VGPRDelta
= NewVGPRPressure
- VGPRCriticalLimit
;
// Report whichever set is further over its limit: the SGPR set when its delta
// is strictly larger, otherwise the VGPR set (the `else` line between the two
// assignment pairs is not visible in this extract).
128 if (SGPRDelta
>= 0 || VGPRDelta
>= 0) {
129 if (SGPRDelta
> VGPRDelta
) {
130 Cand
.RPDelta
.CriticalMax
= PressureChange(SRI
->getSGPRPressureSet());
131 Cand
.RPDelta
.CriticalMax
.setUnitInc(SGPRDelta
);
133 Cand
.RPDelta
.CriticalMax
= PressureChange(SRI
->getVGPRPressureSet());
134 Cand
.RPDelta
.CriticalMax
.setUnitInc(VGPRDelta
);
139 // This function is mostly cut and pasted from
140 // GenericScheduler::pickNodeFromQueue()
141 void GCNMaxOccupancySchedStrategy::pickNodeFromQueue(SchedBoundary
&Zone
,
142 const CandPolicy
&ZonePolicy
,
143 const RegPressureTracker
&RPTracker
,
144 SchedCandidate
&Cand
) {
145 const SIRegisterInfo
*SRI
= static_cast<const SIRegisterInfo
*>(TRI
);
146 ArrayRef
<unsigned> Pressure
= RPTracker
.getRegSetPressureAtPos();
147 unsigned SGPRPressure
= Pressure
[SRI
->getSGPRPressureSet()];
148 unsigned VGPRPressure
= Pressure
[SRI
->getVGPRPressureSet()];
149 ReadyQueue
&Q
= Zone
.Available
;
150 for (SUnit
*SU
: Q
) {
152 SchedCandidate
TryCand(ZonePolicy
);
153 initCandidate(TryCand
, SU
, Zone
.isTop(), RPTracker
, SRI
,
154 SGPRPressure
, VGPRPressure
);
155 // Pass SchedBoundary only when comparing nodes from the same boundary.
156 SchedBoundary
*ZoneArg
= Cand
.AtTop
== TryCand
.AtTop
? &Zone
: nullptr;
157 GenericScheduler::tryCandidate(Cand
, TryCand
, ZoneArg
);
158 if (TryCand
.Reason
!= NoCand
) {
159 // Initialize resource delta if needed in case future heuristics query it.
160 if (TryCand
.ResDelta
== SchedResourceDelta())
161 TryCand
.initResourceDelta(Zone
.DAG
, SchedModel
);
162 Cand
.setBest(TryCand
);
167 // This function is mostly cut and pasted from
168 // GenericScheduler::pickNodeBidirectional()
// Chooses the next SUnit when scheduling from both the top and bottom zones,
// and sets IsTopNode to the chosen candidate's side (Cand.AtTop).
// A zone whose pickOnlyChoice() returns a node is handled first; the bodies of
// those branches are not visible in this extract. Otherwise:
//   - each zone's policy is computed with setPolicy();
//   - the cached BotCand/TopCand are recomputed with pickNodeFromQueue() when
//     invalid, already scheduled, or computed under a different policy;
//   - the two candidates are then compared.
// NOTE(review): the declaration of the local Cand and the assignments inside
// the comparison branches are missing from this extract; confirm against the
// full file.
169 SUnit
*GCNMaxOccupancySchedStrategy::pickNodeBidirectional(bool &IsTopNode
) {
170 // Schedule as far as possible in the direction of no choice. This is most
171 // efficient, but also provides the best heuristics for CriticalPSets.
172 if (SUnit
*SU
= Bot
.pickOnlyChoice()) {
176 if (SUnit
*SU
= Top
.pickOnlyChoice()) {
180 // Set the bottom-up policy based on the state of the current bottom zone and
181 // the instructions outside the zone, including the top zone.
182 CandPolicy BotPolicy
;
183 setPolicy(BotPolicy
, /*IsPostRA=*/false, Bot
, &Top
);
184 // Set the top-down policy based on the state of the current top zone and
185 // the instructions outside the zone, including the bottom zone.
186 CandPolicy TopPolicy
;
187 setPolicy(TopPolicy
, /*IsPostRA=*/false, Top
, &Bot
);
189 // See if BotCand is still valid (because we previously scheduled from Top).
190 LLVM_DEBUG(dbgs() << "Picking from Bot:\n");
191 if (!BotCand
.isValid() || BotCand
.SU
->isScheduled
||
192 BotCand
.Policy
!= BotPolicy
) {
193 BotCand
.reset(CandPolicy());
194 pickNodeFromQueue(Bot
, BotPolicy
, DAG
->getBotRPTracker(), BotCand
);
195 assert(BotCand
.Reason
!= NoCand
&& "failed to find the first candidate");
197 LLVM_DEBUG(traceCandidate(BotCand
));
200 // Check if the top Q has a better candidate.
201 LLVM_DEBUG(dbgs() << "Picking from Top:\n");
202 if (!TopCand
.isValid() || TopCand
.SU
->isScheduled
||
203 TopCand
.Policy
!= TopPolicy
) {
204 TopCand
.reset(CandPolicy());
205 pickNodeFromQueue(Top
, TopPolicy
, DAG
->getTopRPTracker(), TopCand
);
206 assert(TopCand
.Reason
!= NoCand
&& "failed to find the first candidate");
208 LLVM_DEBUG(traceCandidate(TopCand
));
211 // Pick best from BotCand and TopCand.
212 LLVM_DEBUG(dbgs() << "Top Cand: "; traceCandidate(TopCand
);
213 dbgs() << "Bot Cand: "; traceCandidate(BotCand
););
// Both sides won for the same reason: re-run tryCandidate() between them to
// break the tie. TopCand.Reason is cleared for the comparison and restored
// afterwards.
215 if (TopCand
.Reason
== BotCand
.Reason
) {
217 GenericSchedulerBase::CandReason TopReason
= TopCand
.Reason
;
218 TopCand
.Reason
= NoCand
;
219 GenericScheduler::tryCandidate(Cand
, TopCand
, nullptr);
220 if (TopCand
.Reason
!= NoCand
) {
221 Cand
.setBest(TopCand
);
223 TopCand
.Reason
= TopReason
;
// Presumably the differing-reason path: a side that won on RegExcess or
// RegCritical with a non-positive unit increment is checked in this order:
// Top excess, Bot excess, Top critical, Bot critical. The bodies of these
// branches are not visible in this extract.
226 if (TopCand
.Reason
== RegExcess
&& TopCand
.RPDelta
.Excess
.getUnitInc() <= 0) {
228 } else if (BotCand
.Reason
== RegExcess
&& BotCand
.RPDelta
.Excess
.getUnitInc() <= 0) {
230 } else if (TopCand
.Reason
== RegCritical
&& TopCand
.RPDelta
.CriticalMax
.getUnitInc() <= 0) {
232 } else if (BotCand
.Reason
== RegCritical
&& BotCand
.RPDelta
.CriticalMax
.getUnitInc() <= 0) {
// Fallback: compare the CandReason enum values of the two candidates directly.
235 if (BotCand
.Reason
> TopCand
.Reason
) {
242 LLVM_DEBUG(dbgs() << "Picking: "; traceCandidate(Cand
););
244 IsTopNode
= Cand
.AtTop
;
248 // This function is mostly cut and pasted from
249 // GenericScheduler::pickNode()
// Returns the next SUnit to schedule and reports its side through IsTopNode.
// Visible logic:
//   - when DAG->top() == DAG->bottom(), asserts that all ready queues are empty;
//   - for OnlyTopDown / OnlyBottomUp regions, takes that zone's only choice and
//     otherwise (presumably when there is none) picks from that zone's queue
//     with NoPolicy;
//   - in all other regions, delegates to pickNodeBidirectional().
// Picking repeats while the chosen SU is already scheduled.
// NOTE(review): the loop head, the return statements and the bodies following
// isTopReady()/isBottomReady() are missing from this extract.
250 SUnit
*GCNMaxOccupancySchedStrategy::pickNode(bool &IsTopNode
) {
251 if (DAG
->top() == DAG
->bottom()) {
252 assert(Top
.Available
.empty() && Top
.Pending
.empty() &&
253 Bot
.Available
.empty() && Bot
.Pending
.empty() && "ReadyQ garbage");
258 if (RegionPolicy
.OnlyTopDown
) {
259 SU
= Top
.pickOnlyChoice();
262 TopCand
.reset(NoPolicy
);
263 pickNodeFromQueue(Top
, NoPolicy
, DAG
->getTopRPTracker(), TopCand
);
264 assert(TopCand
.Reason
!= NoCand
&& "failed to find a candidate");
268 } else if (RegionPolicy
.OnlyBottomUp
) {
269 SU
= Bot
.pickOnlyChoice();
272 BotCand
.reset(NoPolicy
);
273 pickNodeFromQueue(Bot
, NoPolicy
, DAG
->getBotRPTracker(), BotCand
);
274 assert(BotCand
.Reason
!= NoCand
&& "failed to find a candidate");
279 SU
= pickNodeBidirectional(IsTopNode
);
281 } while (SU
->isScheduled
);
283 if (SU
->isTopReady())
285 if (SU
->isBottomReady())
288 LLVM_DEBUG(dbgs() << "Scheduling SU(" << SU
->NodeNum
<< ") "
293 GCNScheduleDAGMILive::GCNScheduleDAGMILive(MachineSchedContext
*C
,
294 std::unique_ptr
<MachineSchedStrategy
> S
) :
295 ScheduleDAGMILive(C
, std::move(S
)),
296 ST(MF
.getSubtarget
<GCNSubtarget
>()),
297 MFI(*MF
.getInfo
<SIMachineFunctionInfo
>()),
298 StartingOccupancy(MFI
.getOccupancy()),
299 MinOccupancy(StartingOccupancy
), Stage(0), RegionIdx(0) {
301 LLVM_DEBUG(dbgs() << "Starting occupancy is " << StartingOccupancy
<< ".\n");
// Schedules the current region [RegionBegin, RegionEnd). Visible steps:
//   1. Record the region in Regions (on the first pass, per the comment below).
//   2. Save the original instruction order in Unsched.
//   3. Run the generic ScheduleDAGMILive::schedule().
//   4. Measure the real register pressure after scheduling. If it is within
//      the strategy's SGPR/VGPR critical limits, keep the result.
//   5. Otherwise compare occupancy before and after scheduling. The
//      function-wide MinOccupancy (and MFI's occupancy) may be lowered. If
//      WavesAfter is still below MinOccupancy, the original order in Unsched
//      is restored and the live intervals and operand flags are updated.
// NOTE(review): several control-flow lines (stage checks, early returns,
// closing braces) are missing from this extract; confirm the exact conditions
// against the full file.
304 void GCNScheduleDAGMILive::schedule() {
306 // Just record regions at the first pass.
307 Regions
.push_back(std::make_pair(RegionBegin
, RegionEnd
));
// Remember the pre-scheduling instruction order so it can be restored if the
// new schedule lowers occupancy too much.
311 std::vector
<MachineInstr
*> Unsched
;
312 Unsched
.reserve(NumRegionInstrs
);
313 for (auto &I
: *this) {
314 Unsched
.push_back(&I
);
317 GCNRegPressure PressureBefore
;
319 PressureBefore
= Pressure
[RegionIdx
];
321 LLVM_DEBUG(dbgs() << "Pressure before scheduling:\nRegion live-ins:";
322 GCNRPTracker::printLiveRegs(dbgs(), LiveIns
[RegionIdx
], MRI
);
323 dbgs() << "Region live-in pressure: ";
324 llvm::getRegPressure(MRI
, LiveIns
[RegionIdx
]).print(dbgs());
325 dbgs() << "Region register pressure: ";
326 PressureBefore
.print(dbgs()));
329 ScheduleDAGMILive::schedule();
330 Regions
[RegionIdx
] = std::make_pair(RegionBegin
, RegionEnd
);
335 // Check the results of scheduling.
336 GCNMaxOccupancySchedStrategy
&S
= (GCNMaxOccupancySchedStrategy
&)*SchedImpl
;
337 auto PressureAfter
= getRealRegPressure();
339 LLVM_DEBUG(dbgs() << "Pressure after scheduling: ";
340 PressureAfter
.print(dbgs()));
342 if (PressureAfter
.getSGPRNum() <= S
.SGPRCriticalLimit
&&
343 PressureAfter
.getVGPRNum() <= S
.VGPRCriticalLimit
) {
344 Pressure
[RegionIdx
] = PressureAfter
;
345 LLVM_DEBUG(dbgs() << "Pressure in desired limits, done.\n");
// Both occupancies are capped by the function's current occupancy.
348 unsigned Occ
= MFI
.getOccupancy();
349 unsigned WavesAfter
= std::min(Occ
, PressureAfter
.getOccupancy(ST
));
350 unsigned WavesBefore
= std::min(Occ
, PressureBefore
.getOccupancy(ST
));
351 LLVM_DEBUG(dbgs() << "Occupancy before scheduling: " << WavesBefore
352 << ", after " << WavesAfter
<< ".\n");
354 // We could not keep current target occupancy because of the just scheduled
355 // region. Record new occupancy for next scheduling cycle.
356 unsigned NewOccupancy
= std::max(WavesAfter
, WavesBefore
);
357 // Allow memory bound functions to drop occupancy, but no lower than
// MFI.getMinAllowedOccupancy().
359 if (WavesAfter
< WavesBefore
&& WavesAfter
< MinOccupancy
&&
360 WavesAfter
>= MFI
.getMinAllowedOccupancy()) {
361 LLVM_DEBUG(dbgs() << "Function is memory bound, allow occupancy drop up to "
362 << MFI
.getMinAllowedOccupancy() << " waves\n");
363 NewOccupancy
= WavesAfter
;
365 if (NewOccupancy
< MinOccupancy
) {
366 MinOccupancy
= NewOccupancy
;
367 MFI
.limitOccupancy(MinOccupancy
);
368 LLVM_DEBUG(dbgs() << "Occupancy lowered for the function to "
369 << MinOccupancy
<< ".\n");
372 if (WavesAfter
>= MinOccupancy
) {
373 Pressure
[RegionIdx
] = PressureAfter
;
// Revert: re-insert the instructions in their original (Unsched) order,
// moving their live intervals and recomputing operand flags.
377 LLVM_DEBUG(dbgs() << "Attempting to revert scheduling.\n");
378 RegionEnd
= RegionBegin
;
379 for (MachineInstr
*MI
: Unsched
) {
380 if (MI
->isDebugInstr())
383 if (MI
->getIterator() != RegionEnd
) {
385 BB
->insert(RegionEnd
, MI
);
386 if (!MI
->isDebugInstr())
387 LIS
->handleMove(*MI
, true);
389 // Reset read-undef flags and update them later.
390 for (auto &Op
: MI
->operands())
391 if (Op
.isReg() && Op
.isDef())
392 Op
.setIsUndef(false);
393 RegisterOperands RegOpers
;
394 RegOpers
.collect(*MI
, *TRI
, MRI
, ShouldTrackLaneMasks
, false);
395 if (!MI
->isDebugInstr()) {
396 if (ShouldTrackLaneMasks
) {
397 // Adjust liveness and add missing dead+read-undef flags.
398 SlotIndex SlotIdx
= LIS
->getInstructionIndex(*MI
).getRegSlot();
399 RegOpers
.adjustLaneLiveness(*LIS
, MRI
, SlotIdx
, MI
);
401 // Adjust for missing dead-def flags.
402 RegOpers
.detectDeadDefs(*MI
, *LIS
);
405 RegionEnd
= MI
->getIterator();
407 LLVM_DEBUG(dbgs() << "Scheduling " << *MI
);
409 RegionBegin
= Unsched
.front()->getIterator();
410 Regions
[RegionIdx
] = std::make_pair(RegionBegin
, RegionEnd
);
415 GCNRegPressure
GCNScheduleDAGMILive::getRealRegPressure() const {
416 GCNDownwardRPTracker
RPTracker(*LIS
);
417 RPTracker
.advance(begin(), end(), &LiveIns
[RegionIdx
]);
418 return RPTracker
.moveMaxPressure();
// Walks MBB with a GCNDownwardRPTracker and, for the regions of MBB from
// RegionIdx onward, stores:
//   - each region's live-in register set in LiveIns[CurRegion];
//   - each region's maximum register pressure in Pressure[CurRegion].
// If a live set for MBB was previously saved in MBBLiveIns, it seeds the
// tracker and the entry is erased. If MBB has exactly one non-empty successor
// that starts after MBB, the live set at the end of MBB is saved in
// MBBLiveIns for that successor.
// NOTE(review): loop headers, break/else lines and closing braces are missing
// from this extract; confirm the iteration details against the full file.
421 void GCNScheduleDAGMILive::computeBlockPressure(const MachineBasicBlock
*MBB
) {
422 GCNDownwardRPTracker
RPTracker(*LIS
);
424 // If the block has only one successor, the live-ins of that successor are
425 // the live-outs of the current block. We can reuse the calculated live set
426 // if the successor will be sent to scheduling after the current block.
427 const MachineBasicBlock
*OnlySucc
= nullptr;
428 if (MBB
->succ_size() == 1 && !(*MBB
->succ_begin())->empty()) {
429 SlotIndexes
*Ind
= LIS
->getSlotIndexes();
430 if (Ind
->getMBBStartIdx(MBB
) < Ind
->getMBBStartIdx(*MBB
->succ_begin()))
431 OnlySucc
= *MBB
->succ_begin();
434 // Scheduler sends regions from the end of the block upwards.
435 size_t CurRegion
= RegionIdx
;
436 for (size_t E
= Regions
.size(); CurRegion
!= E
; ++CurRegion
)
437 if (Regions
[CurRegion
].first
->getParent() != MBB
)
441 auto I
= MBB
->begin();
// Reuse a live set saved for MBB by its predecessor, if there is one.
442 auto LiveInIt
= MBBLiveIns
.find(MBB
);
443 if (LiveInIt
!= MBBLiveIns
.end()) {
444 auto LiveIn
= std::move(LiveInIt
->second
);
445 RPTracker
.reset(*MBB
->begin(), &LiveIn
);
446 MBBLiveIns
.erase(LiveInIt
);
448 I
= Regions
[CurRegion
].first
;
453 I
= RPTracker
.getNext();
// At a region's first instruction: capture its live-ins and restart the
// max-pressure measurement.
455 if (Regions
[CurRegion
].first
== I
) {
456 LiveIns
[CurRegion
] = RPTracker
.getLiveRegs();
457 RPTracker
.clearMaxPressure();
// At a region's end: record its max pressure and step to the previous region,
// stopping once RegionIdx has been processed.
460 if (Regions
[CurRegion
].second
== I
) {
461 Pressure
[CurRegion
] = RPTracker
.moveMaxPressure();
462 if (CurRegion
-- == RegionIdx
)
465 RPTracker
.advanceToNext();
466 RPTracker
.advanceBeforeNext();
// Advance to the end of the block and hand its live-outs to OnlySucc.
470 if (I
!= MBB
->end()) {
471 RPTracker
.advanceToNext();
472 RPTracker
.advance(MBB
->end());
474 RPTracker
.reset(*OnlySucc
->begin(), &RPTracker
.getLiveRegs());
475 RPTracker
.advanceBeforeNext();
476 MBBLiveIns
[OnlySucc
] = RPTracker
.moveLiveRegs();
480 void GCNScheduleDAGMILive::finalizeSchedule() {
481 GCNMaxOccupancySchedStrategy
&S
= (GCNMaxOccupancySchedStrategy
&)*SchedImpl
;
482 LLVM_DEBUG(dbgs() << "All regions recorded, starting actual scheduling.\n");
484 LiveIns
.resize(Regions
.size());
485 Pressure
.resize(Regions
.size());
490 MachineBasicBlock
*MBB
= nullptr;
493 // Retry function scheduling if we found resulting occupancy and it is
494 // lower than used for first pass scheduling. This will give more freedom
495 // to schedule low register pressure blocks.
496 // Code is partially copied from MachineSchedulerBase::scheduleRegions().
498 if (!LIS
|| StartingOccupancy
<= MinOccupancy
)
503 << "Retrying function scheduling with lowest recorded occupancy "
504 << MinOccupancy
<< ".\n");
506 S
.setTargetOccupancy(MinOccupancy
);
509 for (auto Region
: Regions
) {
510 RegionBegin
= Region
.first
;
511 RegionEnd
= Region
.second
;
513 if (RegionBegin
->getParent() != MBB
) {
514 if (MBB
) finishBlock();
515 MBB
= RegionBegin
->getParent();
518 computeBlockPressure(MBB
);
521 unsigned NumRegionInstrs
= std::distance(begin(), end());
522 enterRegion(MBB
, begin(), end(), NumRegionInstrs
);
524 // Skip empty scheduling regions (0 or 1 schedulable instructions).
525 if (begin() == end() || begin() == std::prev(end())) {
530 LLVM_DEBUG(dbgs() << "********** MI Scheduling **********\n");
531 LLVM_DEBUG(dbgs() << MF
.getName() << ":" << printMBBReference(*MBB
) << " "
532 << MBB
->getName() << "\n From: " << *begin()
534 if (RegionEnd
!= MBB
->end()) dbgs() << *RegionEnd
;
535 else dbgs() << "End";
536 dbgs() << " RegionInstrs: " << NumRegionInstrs
<< '\n');