[ARM] Better patterns for fp <> predicate vectors
[llvm-complete.git] / lib / Target / AMDGPU / GCNRegPressure.cpp
blob39460fbd8a84d63fa1259aa040b12fa5fb7877f5
1 //===- GCNRegPressure.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "GCNRegPressure.h"
10 #include "AMDGPUSubtarget.h"
11 #include "SIRegisterInfo.h"
12 #include "llvm/ADT/SmallVector.h"
13 #include "llvm/CodeGen/LiveInterval.h"
14 #include "llvm/CodeGen/LiveIntervals.h"
15 #include "llvm/CodeGen/MachineInstr.h"
16 #include "llvm/CodeGen/MachineOperand.h"
17 #include "llvm/CodeGen/MachineRegisterInfo.h"
18 #include "llvm/CodeGen/RegisterPressure.h"
19 #include "llvm/CodeGen/SlotIndexes.h"
20 #include "llvm/CodeGen/TargetRegisterInfo.h"
21 #include "llvm/Config/llvm-config.h"
22 #include "llvm/MC/LaneBitmask.h"
23 #include "llvm/Support/Compiler.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include <algorithm>
28 #include <cassert>
30 using namespace llvm;
32 #define DEBUG_TYPE "machine-scheduler"
34 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
35 LLVM_DUMP_METHOD
36 void llvm::printLivesAt(SlotIndex SI,
37 const LiveIntervals &LIS,
38 const MachineRegisterInfo &MRI) {
39 dbgs() << "Live regs at " << SI << ": "
40 << *LIS.getInstructionFromIndex(SI);
41 unsigned Num = 0;
42 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
43 const unsigned Reg = TargetRegisterInfo::index2VirtReg(I);
44 if (!LIS.hasInterval(Reg))
45 continue;
46 const auto &LI = LIS.getInterval(Reg);
47 if (LI.hasSubRanges()) {
48 bool firstTime = true;
49 for (const auto &S : LI.subranges()) {
50 if (!S.liveAt(SI)) continue;
51 if (firstTime) {
52 dbgs() << " " << printReg(Reg, MRI.getTargetRegisterInfo())
53 << '\n';
54 firstTime = false;
56 dbgs() << " " << S << '\n';
57 ++Num;
59 } else if (LI.liveAt(SI)) {
60 dbgs() << " " << LI << '\n';
61 ++Num;
64 if (!Num) dbgs() << " <none>\n";
66 #endif
68 bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
69 const GCNRPTracker::LiveRegSet &S2) {
70 if (S1.size() != S2.size())
71 return false;
73 for (const auto &P : S1) {
74 auto I = S2.find(P.first);
75 if (I == S2.end() || I->second != P.second)
76 return false;
78 return true;
82 ///////////////////////////////////////////////////////////////////////////////
83 // GCNRegPressure
85 unsigned GCNRegPressure::getRegKind(unsigned Reg,
86 const MachineRegisterInfo &MRI) {
87 assert(TargetRegisterInfo::isVirtualRegister(Reg));
88 const auto RC = MRI.getRegClass(Reg);
89 auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
90 return STI->isSGPRClass(RC) ?
91 (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
92 STI->hasAGPRs(RC) ?
93 (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
94 (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
97 void GCNRegPressure::inc(unsigned Reg,
98 LaneBitmask PrevMask,
99 LaneBitmask NewMask,
100 const MachineRegisterInfo &MRI) {
101 if (NewMask == PrevMask)
102 return;
104 int Sign = 1;
105 if (NewMask < PrevMask) {
106 std::swap(NewMask, PrevMask);
107 Sign = -1;
109 #ifndef NDEBUG
110 const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg);
111 #endif
112 switch (auto Kind = getRegKind(Reg, MRI)) {
113 case SGPR32:
114 case VGPR32:
115 case AGPR32:
116 assert(PrevMask.none() && NewMask == MaxMask);
117 Value[Kind] += Sign;
118 break;
120 case SGPR_TUPLE:
121 case VGPR_TUPLE:
122 case AGPR_TUPLE:
123 assert(NewMask < MaxMask || NewMask == MaxMask);
124 assert(PrevMask < NewMask);
126 Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
127 Sign * (~PrevMask & NewMask).getNumLanes();
129 if (PrevMask.none()) {
130 assert(NewMask.any());
131 Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
133 break;
135 default: llvm_unreachable("Unknown register kind");
139 bool GCNRegPressure::less(const GCNSubtarget &ST,
140 const GCNRegPressure& O,
141 unsigned MaxOccupancy) const {
142 const auto SGPROcc = std::min(MaxOccupancy,
143 ST.getOccupancyWithNumSGPRs(getSGPRNum()));
144 const auto VGPROcc = std::min(MaxOccupancy,
145 ST.getOccupancyWithNumVGPRs(getVGPRNum()));
146 const auto OtherSGPROcc = std::min(MaxOccupancy,
147 ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
148 const auto OtherVGPROcc = std::min(MaxOccupancy,
149 ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
151 const auto Occ = std::min(SGPROcc, VGPROcc);
152 const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
153 if (Occ != OtherOcc)
154 return Occ > OtherOcc;
156 bool SGPRImportant = SGPROcc < VGPROcc;
157 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
159 // if both pressures disagree on what is more important compare vgprs
160 if (SGPRImportant != OtherSGPRImportant) {
161 SGPRImportant = false;
164 // compare large regs pressure
165 bool SGPRFirst = SGPRImportant;
166 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
167 if (SGPRFirst) {
168 auto SW = getSGPRTuplesWeight();
169 auto OtherSW = O.getSGPRTuplesWeight();
170 if (SW != OtherSW)
171 return SW < OtherSW;
172 } else {
173 auto VW = getVGPRTuplesWeight();
174 auto OtherVW = O.getVGPRTuplesWeight();
175 if (VW != OtherVW)
176 return VW < OtherVW;
179 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
180 (getVGPRNum() < O.getVGPRNum());
183 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
184 LLVM_DUMP_METHOD
185 void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
186 OS << "VGPRs: " << getVGPRNum();
187 if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
188 OS << ", SGPRs: " << getSGPRNum();
189 if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
190 OS << ", LVGPR WT: " << getVGPRTuplesWeight()
191 << ", LSGPR WT: " << getSGPRTuplesWeight();
192 if (ST) OS << " -> Occ: " << getOccupancy(*ST);
193 OS << '\n';
195 #endif
197 static LaneBitmask getDefRegMask(const MachineOperand &MO,
198 const MachineRegisterInfo &MRI) {
199 assert(MO.isDef() && MO.isReg() &&
200 TargetRegisterInfo::isVirtualRegister(MO.getReg()));
202 // We don't rely on read-undef flag because in case of tentative schedule
203 // tracking it isn't set correctly yet. This works correctly however since
204 // use mask has been tracked before using LIS.
205 return MO.getSubReg() == 0 ?
206 MRI.getMaxLaneMaskForVReg(MO.getReg()) :
207 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
210 static LaneBitmask getUsedRegMask(const MachineOperand &MO,
211 const MachineRegisterInfo &MRI,
212 const LiveIntervals &LIS) {
213 assert(MO.isUse() && MO.isReg() &&
214 TargetRegisterInfo::isVirtualRegister(MO.getReg()));
216 if (auto SubReg = MO.getSubReg())
217 return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
219 auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
220 if (MaxMask == LaneBitmask::getLane(0)) // cannot have subregs
221 return MaxMask;
223 // For a tentative schedule LIS isn't updated yet but livemask should remain
224 // the same on any schedule. Subreg defs can be reordered but they all must
225 // dominate uses anyway.
226 auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
227 return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
230 static SmallVector<RegisterMaskPair, 8>
231 collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
232 const MachineRegisterInfo &MRI) {
233 SmallVector<RegisterMaskPair, 8> Res;
234 for (const auto &MO : MI.operands()) {
235 if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()))
236 continue;
237 if (!MO.isUse() || !MO.readsReg())
238 continue;
240 auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
242 auto Reg = MO.getReg();
243 auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
244 return RM.RegUnit == Reg;
246 if (I != Res.end())
247 I->LaneMask |= UsedMask;
248 else
249 Res.push_back(RegisterMaskPair(Reg, UsedMask));
251 return Res;
254 ///////////////////////////////////////////////////////////////////////////////
255 // GCNRPTracker
257 LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
258 SlotIndex SI,
259 const LiveIntervals &LIS,
260 const MachineRegisterInfo &MRI) {
261 LaneBitmask LiveMask;
262 const auto &LI = LIS.getInterval(Reg);
263 if (LI.hasSubRanges()) {
264 for (const auto &S : LI.subranges())
265 if (S.liveAt(SI)) {
266 LiveMask |= S.LaneMask;
267 assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
268 LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
270 } else if (LI.liveAt(SI)) {
271 LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
273 return LiveMask;
276 GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
277 const LiveIntervals &LIS,
278 const MachineRegisterInfo &MRI) {
279 GCNRPTracker::LiveRegSet LiveRegs;
280 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
281 auto Reg = TargetRegisterInfo::index2VirtReg(I);
282 if (!LIS.hasInterval(Reg))
283 continue;
284 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
285 if (LiveMask.any())
286 LiveRegs[Reg] = LiveMask;
288 return LiveRegs;
291 void GCNRPTracker::reset(const MachineInstr &MI,
292 const LiveRegSet *LiveRegsCopy,
293 bool After) {
294 const MachineFunction &MF = *MI.getMF();
295 MRI = &MF.getRegInfo();
296 if (LiveRegsCopy) {
297 if (&LiveRegs != LiveRegsCopy)
298 LiveRegs = *LiveRegsCopy;
299 } else {
300 LiveRegs = After ? getLiveRegsAfter(MI, LIS)
301 : getLiveRegsBefore(MI, LIS);
304 MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
307 void GCNUpwardRPTracker::reset(const MachineInstr &MI,
308 const LiveRegSet *LiveRegsCopy) {
309 GCNRPTracker::reset(MI, LiveRegsCopy, true);
312 void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
313 assert(MRI && "call reset first");
315 LastTrackedMI = &MI;
317 if (MI.isDebugInstr())
318 return;
320 auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
322 // calc pressure at the MI (defs + uses)
323 auto AtMIPressure = CurPressure;
324 for (const auto &U : RegUses) {
325 auto LiveMask = LiveRegs[U.RegUnit];
326 AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
328 // update max pressure
329 MaxPressure = max(AtMIPressure, MaxPressure);
331 for (const auto &MO : MI.defs()) {
332 if (!MO.isReg() || !TargetRegisterInfo::isVirtualRegister(MO.getReg()) ||
333 MO.isDead())
334 continue;
336 auto Reg = MO.getReg();
337 auto I = LiveRegs.find(Reg);
338 if (I == LiveRegs.end())
339 continue;
340 auto &LiveMask = I->second;
341 auto PrevMask = LiveMask;
342 LiveMask &= ~getDefRegMask(MO, *MRI);
343 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
344 if (LiveMask.none())
345 LiveRegs.erase(I);
347 for (const auto &U : RegUses) {
348 auto &LiveMask = LiveRegs[U.RegUnit];
349 auto PrevMask = LiveMask;
350 LiveMask |= U.LaneMask;
351 CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
353 assert(CurPressure == getRegPressure(*MRI, LiveRegs));
356 bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
357 const LiveRegSet *LiveRegsCopy) {
358 MRI = &MI.getParent()->getParent()->getRegInfo();
359 LastTrackedMI = nullptr;
360 MBBEnd = MI.getParent()->end();
361 NextMI = &MI;
362 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
363 if (NextMI == MBBEnd)
364 return false;
365 GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
366 return true;
369 bool GCNDownwardRPTracker::advanceBeforeNext() {
370 assert(MRI && "call reset first");
372 NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
373 if (NextMI == MBBEnd)
374 return false;
376 SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
377 assert(SI.isValid());
379 // Remove dead registers or mask bits.
380 for (auto &It : LiveRegs) {
381 const LiveInterval &LI = LIS.getInterval(It.first);
382 if (LI.hasSubRanges()) {
383 for (const auto &S : LI.subranges()) {
384 if (!S.liveAt(SI)) {
385 auto PrevMask = It.second;
386 It.second &= ~S.LaneMask;
387 CurPressure.inc(It.first, PrevMask, It.second, *MRI);
390 } else if (!LI.liveAt(SI)) {
391 auto PrevMask = It.second;
392 It.second = LaneBitmask::getNone();
393 CurPressure.inc(It.first, PrevMask, It.second, *MRI);
395 if (It.second.none())
396 LiveRegs.erase(It.first);
399 MaxPressure = max(MaxPressure, CurPressure);
401 return true;
404 void GCNDownwardRPTracker::advanceToNext() {
405 LastTrackedMI = &*NextMI++;
407 // Add new registers or mask bits.
408 for (const auto &MO : LastTrackedMI->defs()) {
409 if (!MO.isReg())
410 continue;
411 unsigned Reg = MO.getReg();
412 if (!TargetRegisterInfo::isVirtualRegister(Reg))
413 continue;
414 auto &LiveMask = LiveRegs[Reg];
415 auto PrevMask = LiveMask;
416 LiveMask |= getDefRegMask(MO, *MRI);
417 CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
420 MaxPressure = max(MaxPressure, CurPressure);
423 bool GCNDownwardRPTracker::advance() {
424 // If we have just called reset live set is actual.
425 if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
426 return false;
427 advanceToNext();
428 return true;
431 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
432 while (NextMI != End)
433 if (!advance()) return false;
434 return true;
437 bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
438 MachineBasicBlock::const_iterator End,
439 const LiveRegSet *LiveRegsCopy) {
440 reset(*Begin, LiveRegsCopy);
441 return advance(End);
444 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
445 LLVM_DUMP_METHOD
446 static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
447 const GCNRPTracker::LiveRegSet &TrackedLR,
448 const TargetRegisterInfo *TRI) {
449 for (auto const &P : TrackedLR) {
450 auto I = LISLR.find(P.first);
451 if (I == LISLR.end()) {
452 dbgs() << " " << printReg(P.first, TRI)
453 << ":L" << PrintLaneMask(P.second)
454 << " isn't found in LIS reported set\n";
456 else if (I->second != P.second) {
457 dbgs() << " " << printReg(P.first, TRI)
458 << " masks doesn't match: LIS reported "
459 << PrintLaneMask(I->second)
460 << ", tracked "
461 << PrintLaneMask(P.second)
462 << '\n';
465 for (auto const &P : LISLR) {
466 auto I = TrackedLR.find(P.first);
467 if (I == TrackedLR.end()) {
468 dbgs() << " " << printReg(P.first, TRI)
469 << ":L" << PrintLaneMask(P.second)
470 << " isn't found in tracked set\n";
475 bool GCNUpwardRPTracker::isValid() const {
476 const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
477 const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
478 const auto &TrackedLR = LiveRegs;
480 if (!isEqual(LISLR, TrackedLR)) {
481 dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
482 " LIS reported livesets mismatch:\n";
483 printLivesAt(SI, LIS, *MRI);
484 reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
485 return false;
488 auto LISPressure = getRegPressure(*MRI, LISLR);
489 if (LISPressure != CurPressure) {
490 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
491 CurPressure.print(dbgs());
492 dbgs() << "LIS rpt: ";
493 LISPressure.print(dbgs());
494 return false;
496 return true;
499 void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
500 const MachineRegisterInfo &MRI) {
501 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
502 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
503 unsigned Reg = TargetRegisterInfo::index2VirtReg(I);
504 auto It = LiveRegs.find(Reg);
505 if (It != LiveRegs.end() && It->second.any())
506 OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
507 << PrintLaneMask(It->second);
509 OS << '\n';
511 #endif