[RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (...
[llvm-project.git] / llvm / lib / CodeGen / RDFRegisters.cpp
blob7ce00a66b3ae6c888d1051a166bc1a2275964fc8
1 //===- RDFRegisters.cpp ---------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/ADT/BitVector.h"
10 #include "llvm/CodeGen/MachineFunction.h"
11 #include "llvm/CodeGen/MachineInstr.h"
12 #include "llvm/CodeGen/MachineOperand.h"
13 #include "llvm/CodeGen/RDFRegisters.h"
14 #include "llvm/CodeGen/TargetRegisterInfo.h"
15 #include "llvm/MC/LaneBitmask.h"
16 #include "llvm/MC/MCRegisterInfo.h"
17 #include "llvm/Support/ErrorHandling.h"
18 #include "llvm/Support/Format.h"
19 #include "llvm/Support/MathExtras.h"
20 #include "llvm/Support/raw_ostream.h"
21 #include <cassert>
22 #include <cstdint>
23 #include <set>
24 #include <utility>
26 namespace llvm::rdf {
28 PhysicalRegisterInfo::PhysicalRegisterInfo(const TargetRegisterInfo &tri,
29 const MachineFunction &mf)
30 : TRI(tri) {
31 RegInfos.resize(TRI.getNumRegs());
33 BitVector BadRC(TRI.getNumRegs());
34 for (const TargetRegisterClass *RC : TRI.regclasses()) {
35 for (MCPhysReg R : *RC) {
36 RegInfo &RI = RegInfos[R];
37 if (RI.RegClass != nullptr && !BadRC[R]) {
38 if (RC->LaneMask != RI.RegClass->LaneMask) {
39 BadRC.set(R);
40 RI.RegClass = nullptr;
42 } else
43 RI.RegClass = RC;
47 UnitInfos.resize(TRI.getNumRegUnits());
49 for (uint32_t U = 0, NU = TRI.getNumRegUnits(); U != NU; ++U) {
50 if (UnitInfos[U].Reg != 0)
51 continue;
52 MCRegUnitRootIterator R(U, &TRI);
53 assert(R.isValid());
54 RegisterId F = *R;
55 ++R;
56 if (R.isValid()) {
57 UnitInfos[U].Mask = LaneBitmask::getAll();
58 UnitInfos[U].Reg = F;
59 } else {
60 for (MCRegUnitMaskIterator I(F, &TRI); I.isValid(); ++I) {
61 std::pair<uint32_t, LaneBitmask> P = *I;
62 UnitInfo &UI = UnitInfos[P.first];
63 UI.Reg = F;
64 UI.Mask = P.second;
69 for (const uint32_t *RM : TRI.getRegMasks())
70 RegMasks.insert(RM);
71 for (const MachineBasicBlock &B : mf)
72 for (const MachineInstr &In : B)
73 for (const MachineOperand &Op : In.operands())
74 if (Op.isRegMask())
75 RegMasks.insert(Op.getRegMask());
77 MaskInfos.resize(RegMasks.size() + 1);
78 for (uint32_t M = 1, NM = RegMasks.size(); M <= NM; ++M) {
79 BitVector PU(TRI.getNumRegUnits());
80 const uint32_t *MB = RegMasks.get(M);
81 for (unsigned I = 1, E = TRI.getNumRegs(); I != E; ++I) {
82 if (!(MB[I / 32] & (1u << (I % 32))))
83 continue;
84 for (MCRegUnit Unit : TRI.regunits(MCRegister::from(I)))
85 PU.set(Unit);
87 MaskInfos[M].Units = PU.flip();
90 AliasInfos.resize(TRI.getNumRegUnits());
91 for (uint32_t U = 0, NU = TRI.getNumRegUnits(); U != NU; ++U) {
92 BitVector AS(TRI.getNumRegs());
93 for (MCRegUnitRootIterator R(U, &TRI); R.isValid(); ++R)
94 for (MCPhysReg S : TRI.superregs_inclusive(*R))
95 AS.set(S);
96 AliasInfos[U].Regs = AS;
100 bool PhysicalRegisterInfo::alias(RegisterRef RA, RegisterRef RB) const {
101 return !disjoint(getUnits(RA), getUnits(RB));
104 std::set<RegisterId> PhysicalRegisterInfo::getAliasSet(RegisterId Reg) const {
105 // Do not include Reg in the alias set.
106 std::set<RegisterId> AS;
107 assert(!RegisterRef::isUnitId(Reg) && "No units allowed");
108 if (RegisterRef::isMaskId(Reg)) {
109 // XXX SLOW
110 const uint32_t *MB = getRegMaskBits(Reg);
111 for (unsigned i = 1, e = TRI.getNumRegs(); i != e; ++i) {
112 if (MB[i / 32] & (1u << (i % 32)))
113 continue;
114 AS.insert(i);
116 return AS;
119 assert(RegisterRef::isRegId(Reg));
120 for (MCRegAliasIterator AI(Reg, &TRI, false); AI.isValid(); ++AI)
121 AS.insert(*AI);
123 return AS;
126 std::set<RegisterId> PhysicalRegisterInfo::getUnits(RegisterRef RR) const {
127 std::set<RegisterId> Units;
129 if (RR.Reg == 0)
130 return Units; // Empty
132 if (RR.isReg()) {
133 if (RR.Mask.none())
134 return Units; // Empty
135 for (MCRegUnitMaskIterator UM(RR.idx(), &TRI); UM.isValid(); ++UM) {
136 auto [U, M] = *UM;
137 if ((M & RR.Mask).any())
138 Units.insert(U);
140 return Units;
143 assert(RR.isMask());
144 unsigned NumRegs = TRI.getNumRegs();
145 const uint32_t *MB = getRegMaskBits(RR.idx());
146 for (unsigned I = 0, E = (NumRegs + 31) / 32; I != E; ++I) {
147 uint32_t C = ~MB[I]; // Clobbered regs
148 if (I == 0) // Reg 0 should be ignored
149 C &= maskLeadingOnes<unsigned>(31);
150 if (I + 1 == E && NumRegs % 32 != 0) // Last word may be partial
151 C &= maskTrailingOnes<unsigned>(NumRegs % 32);
152 if (C == 0)
153 continue;
154 while (C != 0) {
155 unsigned T = llvm::countr_zero(C);
156 unsigned CR = 32 * I + T; // Clobbered reg
157 for (MCRegUnit U : TRI.regunits(CR))
158 Units.insert(U);
159 C &= ~(1u << T);
162 return Units;
165 RegisterRef PhysicalRegisterInfo::mapTo(RegisterRef RR, unsigned R) const {
166 if (RR.Reg == R)
167 return RR;
168 if (unsigned Idx = TRI.getSubRegIndex(R, RR.Reg))
169 return RegisterRef(R, TRI.composeSubRegIndexLaneMask(Idx, RR.Mask));
170 if (unsigned Idx = TRI.getSubRegIndex(RR.Reg, R)) {
171 const RegInfo &RI = RegInfos[R];
172 LaneBitmask RCM =
173 RI.RegClass ? RI.RegClass->LaneMask : LaneBitmask::getAll();
174 LaneBitmask M = TRI.reverseComposeSubRegIndexLaneMask(Idx, RR.Mask);
175 return RegisterRef(R, M & RCM);
177 llvm_unreachable("Invalid arguments: unrelated registers?");
180 bool PhysicalRegisterInfo::equal_to(RegisterRef A, RegisterRef B) const {
181 if (!A.isReg() || !B.isReg()) {
182 // For non-regs, or comparing reg and non-reg, use only the Reg member.
183 return A.Reg == B.Reg;
186 if (A.Reg == B.Reg)
187 return A.Mask == B.Mask;
189 // Compare reg units lexicographically.
190 MCRegUnitMaskIterator AI(A.Reg, &getTRI());
191 MCRegUnitMaskIterator BI(B.Reg, &getTRI());
192 while (AI.isValid() && BI.isValid()) {
193 auto [AReg, AMask] = *AI;
194 auto [BReg, BMask] = *BI;
196 // If both iterators point to a unit contained in both A and B, then
197 // compare the units.
198 if ((AMask & A.Mask).any() && (BMask & B.Mask).any()) {
199 if (AReg != BReg)
200 return false;
201 // Units are equal, move on to the next ones.
202 ++AI;
203 ++BI;
204 continue;
207 if ((AMask & A.Mask).none())
208 ++AI;
209 if ((BMask & B.Mask).none())
210 ++BI;
212 // One or both have reached the end.
213 return static_cast<int>(AI.isValid()) == static_cast<int>(BI.isValid());
216 bool PhysicalRegisterInfo::less(RegisterRef A, RegisterRef B) const {
217 if (!A.isReg() || !B.isReg()) {
218 // For non-regs, or comparing reg and non-reg, use only the Reg member.
219 return A.Reg < B.Reg;
222 if (A.Reg == B.Reg)
223 return A.Mask < B.Mask;
224 if (A.Mask == B.Mask)
225 return A.Reg < B.Reg;
227 // Compare reg units lexicographically.
228 llvm::MCRegUnitMaskIterator AI(A.Reg, &getTRI());
229 llvm::MCRegUnitMaskIterator BI(B.Reg, &getTRI());
230 while (AI.isValid() && BI.isValid()) {
231 auto [AReg, AMask] = *AI;
232 auto [BReg, BMask] = *BI;
234 // If both iterators point to a unit contained in both A and B, then
235 // compare the units.
236 if ((AMask & A.Mask).any() && (BMask & B.Mask).any()) {
237 if (AReg != BReg)
238 return AReg < BReg;
239 // Units are equal, move on to the next ones.
240 ++AI;
241 ++BI;
242 continue;
245 if ((AMask & A.Mask).none())
246 ++AI;
247 if ((BMask & B.Mask).none())
248 ++BI;
250 // One or both have reached the end: assume invalid < valid.
251 return static_cast<int>(AI.isValid()) < static_cast<int>(BI.isValid());
254 void PhysicalRegisterInfo::print(raw_ostream &OS, RegisterRef A) const {
255 if (A.Reg == 0 || A.isReg()) {
256 if (0 < A.idx() && A.idx() < TRI.getNumRegs())
257 OS << TRI.getName(A.idx());
258 else
259 OS << printReg(A.idx(), &TRI);
260 OS << PrintLaneMaskShort(A.Mask);
261 } else if (A.isUnit()) {
262 OS << printRegUnit(A.idx(), &TRI);
263 } else {
264 assert(A.isMask());
265 // RegMask SS flag is preserved by idx().
266 unsigned Idx = Register::stackSlot2Index(A.idx());
267 const char *Fmt = Idx < 0x10000 ? "%04x" : "%08x";
268 OS << "M#" << format(Fmt, Idx);
272 void PhysicalRegisterInfo::print(raw_ostream &OS, const RegisterAggr &A) const {
273 OS << '{';
274 for (unsigned U : A.units())
275 OS << ' ' << printRegUnit(U, &TRI);
276 OS << " }";
279 bool RegisterAggr::hasAliasOf(RegisterRef RR) const {
280 if (RR.isMask())
281 return Units.anyCommon(PRI.getMaskUnits(RR.Reg));
283 for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
284 std::pair<uint32_t, LaneBitmask> P = *U;
285 if ((P.second & RR.Mask).any())
286 if (Units.test(P.first))
287 return true;
289 return false;
292 bool RegisterAggr::hasCoverOf(RegisterRef RR) const {
293 if (RR.isMask()) {
294 BitVector T(PRI.getMaskUnits(RR.Reg));
295 return T.reset(Units).none();
298 for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
299 std::pair<uint32_t, LaneBitmask> P = *U;
300 if ((P.second & RR.Mask).any())
301 if (!Units.test(P.first))
302 return false;
304 return true;
307 RegisterAggr &RegisterAggr::insert(RegisterRef RR) {
308 if (RR.isMask()) {
309 Units |= PRI.getMaskUnits(RR.Reg);
310 return *this;
313 for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
314 std::pair<uint32_t, LaneBitmask> P = *U;
315 if ((P.second & RR.Mask).any())
316 Units.set(P.first);
318 return *this;
321 RegisterAggr &RegisterAggr::insert(const RegisterAggr &RG) {
322 Units |= RG.Units;
323 return *this;
326 RegisterAggr &RegisterAggr::intersect(RegisterRef RR) {
327 return intersect(RegisterAggr(PRI).insert(RR));
330 RegisterAggr &RegisterAggr::intersect(const RegisterAggr &RG) {
331 Units &= RG.Units;
332 return *this;
335 RegisterAggr &RegisterAggr::clear(RegisterRef RR) {
336 return clear(RegisterAggr(PRI).insert(RR));
339 RegisterAggr &RegisterAggr::clear(const RegisterAggr &RG) {
340 Units.reset(RG.Units);
341 return *this;
344 RegisterRef RegisterAggr::intersectWith(RegisterRef RR) const {
345 RegisterAggr T(PRI);
346 T.insert(RR).intersect(*this);
347 if (T.empty())
348 return RegisterRef();
349 RegisterRef NR = T.makeRegRef();
350 assert(NR);
351 return NR;
354 RegisterRef RegisterAggr::clearIn(RegisterRef RR) const {
355 return RegisterAggr(PRI).insert(RR).clear(*this).makeRegRef();
358 RegisterRef RegisterAggr::makeRegRef() const {
359 int U = Units.find_first();
360 if (U < 0)
361 return RegisterRef();
363 // Find the set of all registers that are aliased to all the units
364 // in this aggregate.
366 // Get all the registers aliased to the first unit in the bit vector.
367 BitVector Regs = PRI.getUnitAliases(U);
368 U = Units.find_next(U);
370 // For each other unit, intersect it with the set of all registers
371 // aliased that unit.
372 while (U >= 0) {
373 Regs &= PRI.getUnitAliases(U);
374 U = Units.find_next(U);
377 // If there is at least one register remaining, pick the first one,
378 // and consolidate the masks of all of its units contained in this
379 // aggregate.
381 int F = Regs.find_first();
382 if (F <= 0)
383 return RegisterRef();
385 LaneBitmask M;
386 for (MCRegUnitMaskIterator I(F, &PRI.getTRI()); I.isValid(); ++I) {
387 std::pair<uint32_t, LaneBitmask> P = *I;
388 if (Units.test(P.first))
389 M |= P.second;
391 return RegisterRef(F, M);
394 RegisterAggr::ref_iterator::ref_iterator(const RegisterAggr &RG, bool End)
395 : Owner(&RG) {
396 for (int U = RG.Units.find_first(); U >= 0; U = RG.Units.find_next(U)) {
397 RegisterRef R = RG.PRI.getRefForUnit(U);
398 Masks[R.Reg] |= R.Mask;
400 Pos = End ? Masks.end() : Masks.begin();
401 Index = End ? Masks.size() : 0;
404 raw_ostream &operator<<(raw_ostream &OS, const RegisterAggr &A) {
405 A.getPRI().print(OS, A);
406 return OS;
409 raw_ostream &operator<<(raw_ostream &OS, const PrintLaneMaskShort &P) {
410 if (P.Mask.all())
411 return OS;
412 if (P.Mask.none())
413 return OS << ":*none*";
415 LaneBitmask::Type Val = P.Mask.getAsInteger();
416 if ((Val & 0xffff) == Val)
417 return OS << ':' << format("%04llX", Val);
418 if ((Val & 0xffffffff) == Val)
419 return OS << ':' << format("%08llX", Val);
420 return OS << ':' << PrintLaneMask(P.Mask);
423 } // namespace llvm::rdf