[RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (...
[llvm-project.git] / llvm / lib / Target / RISCV / GISel / RISCVRegisterBankInfo.cpp
blobcf0ff63a5e51c29da5e1ad49dc524f63edd23970
1 //===-- RISCVRegisterBankInfo.cpp -------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file implements the targeting of the RegisterBankInfo class for RISC-V.
10 /// \todo This should be generated by TableGen.
11 //===----------------------------------------------------------------------===//
13 #include "RISCVRegisterBankInfo.h"
14 #include "MCTargetDesc/RISCVMCTargetDesc.h"
15 #include "RISCVSubtarget.h"
16 #include "llvm/CodeGen/MachineRegisterInfo.h"
17 #include "llvm/CodeGen/RegisterBank.h"
18 #include "llvm/CodeGen/RegisterBankInfo.h"
19 #include "llvm/CodeGen/TargetRegisterInfo.h"
21 #define GET_TARGET_REGBANK_IMPL
22 #include "RISCVGenRegisterBank.inc"
24 namespace llvm {
25 namespace RISCV {
27 const RegisterBankInfo::PartialMapping PartMappings[] = {
28 {0, 32, GPRBRegBank},
29 {0, 64, GPRBRegBank},
30 {0, 32, FPRBRegBank},
31 {0, 64, FPRBRegBank},
34 enum PartialMappingIdx {
35 PMI_GPRB32 = 0,
36 PMI_GPRB64 = 1,
37 PMI_FPRB32 = 2,
38 PMI_FPRB64 = 3,
41 const RegisterBankInfo::ValueMapping ValueMappings[] = {
42 // Invalid value mapping.
43 {nullptr, 0},
44 // Maximum 3 GPR operands; 32 bit.
45 {&PartMappings[PMI_GPRB32], 1},
46 {&PartMappings[PMI_GPRB32], 1},
47 {&PartMappings[PMI_GPRB32], 1},
48 // Maximum 3 GPR operands; 64 bit.
49 {&PartMappings[PMI_GPRB64], 1},
50 {&PartMappings[PMI_GPRB64], 1},
51 {&PartMappings[PMI_GPRB64], 1},
52 // Maximum 3 FPR operands; 32 bit.
53 {&PartMappings[PMI_FPRB32], 1},
54 {&PartMappings[PMI_FPRB32], 1},
55 {&PartMappings[PMI_FPRB32], 1},
56 // Maximum 3 FPR operands; 64 bit.
57 {&PartMappings[PMI_FPRB64], 1},
58 {&PartMappings[PMI_FPRB64], 1},
59 {&PartMappings[PMI_FPRB64], 1},
62 enum ValueMappingIdx {
63 InvalidIdx = 0,
64 GPRB32Idx = 1,
65 GPRB64Idx = 4,
66 FPRB32Idx = 7,
67 FPRB64Idx = 10,
69 } // namespace RISCV
70 } // namespace llvm
72 using namespace llvm;
74 RISCVRegisterBankInfo::RISCVRegisterBankInfo(unsigned HwMode)
75 : RISCVGenRegisterBankInfo(HwMode) {}
77 const RegisterBank &
78 RISCVRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
79 LLT Ty) const {
80 switch (RC.getID()) {
81 default:
82 llvm_unreachable("Register class not supported");
83 case RISCV::GPRRegClassID:
84 case RISCV::GPRF16RegClassID:
85 case RISCV::GPRF32RegClassID:
86 case RISCV::GPRNoX0RegClassID:
87 case RISCV::GPRNoX0X2RegClassID:
88 case RISCV::GPRJALRRegClassID:
89 case RISCV::GPRTCRegClassID:
90 case RISCV::GPRC_and_GPRTCRegClassID:
91 case RISCV::GPRCRegClassID:
92 case RISCV::GPRC_and_SR07RegClassID:
93 case RISCV::SR07RegClassID:
94 case RISCV::SPRegClassID:
95 case RISCV::GPRX0RegClassID:
96 return getRegBank(RISCV::GPRBRegBankID);
97 case RISCV::FPR64RegClassID:
98 case RISCV::FPR16RegClassID:
99 case RISCV::FPR32RegClassID:
100 case RISCV::FPR64CRegClassID:
101 case RISCV::FPR32CRegClassID:
102 return getRegBank(RISCV::FPRBRegBankID);
103 case RISCV::VMRegClassID:
104 case RISCV::VRRegClassID:
105 case RISCV::VRNoV0RegClassID:
106 case RISCV::VRM2RegClassID:
107 case RISCV::VRM2NoV0RegClassID:
108 case RISCV::VRM4RegClassID:
109 case RISCV::VRM4NoV0RegClassID:
110 case RISCV::VMV0RegClassID:
111 case RISCV::VRM2_with_sub_vrm1_0_in_VMV0RegClassID:
112 case RISCV::VRM4_with_sub_vrm1_0_in_VMV0RegClassID:
113 case RISCV::VRM8RegClassID:
114 case RISCV::VRM8NoV0RegClassID:
115 case RISCV::VRM8_with_sub_vrm1_0_in_VMV0RegClassID:
116 return getRegBank(RISCV::VRBRegBankID);
120 static const RegisterBankInfo::ValueMapping *getFPValueMapping(unsigned Size) {
121 assert(Size == 32 || Size == 64);
122 unsigned Idx = Size == 64 ? RISCV::FPRB64Idx : RISCV::FPRB32Idx;
123 return &RISCV::ValueMappings[Idx];
126 /// Returns whether opcode \p Opc is a pre-isel generic floating-point opcode,
127 /// having only floating-point operands.
128 /// FIXME: this is copied from target AArch64. Needs some code refactor here to
129 /// put this function in GlobalISel/Utils.cpp.
130 static bool isPreISelGenericFloatingPointOpcode(unsigned Opc) {
131 switch (Opc) {
132 case TargetOpcode::G_FADD:
133 case TargetOpcode::G_FSUB:
134 case TargetOpcode::G_FMUL:
135 case TargetOpcode::G_FMA:
136 case TargetOpcode::G_FDIV:
137 case TargetOpcode::G_FCONSTANT:
138 case TargetOpcode::G_FPEXT:
139 case TargetOpcode::G_FPTRUNC:
140 case TargetOpcode::G_FCEIL:
141 case TargetOpcode::G_FFLOOR:
142 case TargetOpcode::G_FNEARBYINT:
143 case TargetOpcode::G_FNEG:
144 case TargetOpcode::G_FCOPYSIGN:
145 case TargetOpcode::G_FCOS:
146 case TargetOpcode::G_FSIN:
147 case TargetOpcode::G_FLOG10:
148 case TargetOpcode::G_FLOG:
149 case TargetOpcode::G_FLOG2:
150 case TargetOpcode::G_FSQRT:
151 case TargetOpcode::G_FABS:
152 case TargetOpcode::G_FEXP:
153 case TargetOpcode::G_FRINT:
154 case TargetOpcode::G_INTRINSIC_TRUNC:
155 case TargetOpcode::G_INTRINSIC_ROUND:
156 case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
157 case TargetOpcode::G_FMAXNUM:
158 case TargetOpcode::G_FMINNUM:
159 case TargetOpcode::G_FMAXIMUM:
160 case TargetOpcode::G_FMINIMUM:
161 return true;
163 return false;
166 // TODO: Make this more like AArch64?
167 bool RISCVRegisterBankInfo::hasFPConstraints(
168 const MachineInstr &MI, const MachineRegisterInfo &MRI,
169 const TargetRegisterInfo &TRI) const {
170 if (isPreISelGenericFloatingPointOpcode(MI.getOpcode()))
171 return true;
173 // If we have a copy instruction, we could be feeding floating point
174 // instructions.
175 if (MI.getOpcode() != TargetOpcode::COPY)
176 return false;
178 return getRegBank(MI.getOperand(0).getReg(), MRI, TRI) == &RISCV::FPRBRegBank;
181 bool RISCVRegisterBankInfo::onlyUsesFP(const MachineInstr &MI,
182 const MachineRegisterInfo &MRI,
183 const TargetRegisterInfo &TRI) const {
184 switch (MI.getOpcode()) {
185 case TargetOpcode::G_FPTOSI:
186 case TargetOpcode::G_FPTOUI:
187 case TargetOpcode::G_FCMP:
188 return true;
189 default:
190 break;
193 return hasFPConstraints(MI, MRI, TRI);
196 bool RISCVRegisterBankInfo::onlyDefinesFP(const MachineInstr &MI,
197 const MachineRegisterInfo &MRI,
198 const TargetRegisterInfo &TRI) const {
199 switch (MI.getOpcode()) {
200 case TargetOpcode::G_SITOFP:
201 case TargetOpcode::G_UITOFP:
202 return true;
203 default:
204 break;
207 return hasFPConstraints(MI, MRI, TRI);
210 bool RISCVRegisterBankInfo::anyUseOnlyUseFP(
211 Register Def, const MachineRegisterInfo &MRI,
212 const TargetRegisterInfo &TRI) const {
213 return any_of(
214 MRI.use_nodbg_instructions(Def),
215 [&](const MachineInstr &UseMI) { return onlyUsesFP(UseMI, MRI, TRI); });
218 const RegisterBankInfo::InstructionMapping &
219 RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
220 const unsigned Opc = MI.getOpcode();
222 // Try the default logic for non-generic instructions that are either copies
223 // or already have some operands assigned to banks.
224 if (!isPreISelGenericOpcode(Opc) || Opc == TargetOpcode::G_PHI) {
225 const InstructionMapping &Mapping = getInstrMappingImpl(MI);
226 if (Mapping.isValid())
227 return Mapping;
230 const MachineFunction &MF = *MI.getParent()->getParent();
231 const MachineRegisterInfo &MRI = MF.getRegInfo();
232 const TargetSubtargetInfo &STI = MF.getSubtarget();
233 const TargetRegisterInfo &TRI = *STI.getRegisterInfo();
235 unsigned GPRSize = getMaximumSize(RISCV::GPRBRegBankID);
236 assert((GPRSize == 32 || GPRSize == 64) && "Unexpected GPR size");
238 unsigned NumOperands = MI.getNumOperands();
239 const ValueMapping *GPRValueMapping =
240 &RISCV::ValueMappings[GPRSize == 64 ? RISCV::GPRB64Idx
241 : RISCV::GPRB32Idx];
243 switch (Opc) {
244 case TargetOpcode::G_ADD:
245 case TargetOpcode::G_SUB:
246 case TargetOpcode::G_SHL:
247 case TargetOpcode::G_ASHR:
248 case TargetOpcode::G_LSHR:
249 case TargetOpcode::G_AND:
250 case TargetOpcode::G_OR:
251 case TargetOpcode::G_XOR:
252 case TargetOpcode::G_MUL:
253 case TargetOpcode::G_SDIV:
254 case TargetOpcode::G_SREM:
255 case TargetOpcode::G_SMULH:
256 case TargetOpcode::G_SMAX:
257 case TargetOpcode::G_SMIN:
258 case TargetOpcode::G_UDIV:
259 case TargetOpcode::G_UREM:
260 case TargetOpcode::G_UMULH:
261 case TargetOpcode::G_UMAX:
262 case TargetOpcode::G_UMIN:
263 case TargetOpcode::G_PTR_ADD:
264 case TargetOpcode::G_PTRTOINT:
265 case TargetOpcode::G_INTTOPTR:
266 case TargetOpcode::G_TRUNC:
267 case TargetOpcode::G_ANYEXT:
268 case TargetOpcode::G_SEXT:
269 case TargetOpcode::G_ZEXT:
270 case TargetOpcode::G_SEXTLOAD:
271 case TargetOpcode::G_ZEXTLOAD:
272 return getInstructionMapping(DefaultMappingID, /*Cost=*/1, GPRValueMapping,
273 NumOperands);
274 case TargetOpcode::G_FADD:
275 case TargetOpcode::G_FSUB:
276 case TargetOpcode::G_FMUL:
277 case TargetOpcode::G_FDIV:
278 case TargetOpcode::G_FABS:
279 case TargetOpcode::G_FNEG:
280 case TargetOpcode::G_FSQRT:
281 case TargetOpcode::G_FMAXNUM:
282 case TargetOpcode::G_FMINNUM: {
283 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
284 return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
285 getFPValueMapping(Ty.getSizeInBits()),
286 NumOperands);
288 case TargetOpcode::G_IMPLICIT_DEF: {
289 Register Dst = MI.getOperand(0).getReg();
290 auto Mapping = GPRValueMapping;
291 // FIXME: May need to do a better job determining when to use FPRB.
292 // For example, the look through COPY case:
293 // %0:_(s32) = G_IMPLICIT_DEF
294 // %1:_(s32) = COPY %0
295 // $f10_d = COPY %1(s32)
296 if (anyUseOnlyUseFP(Dst, MRI, TRI))
297 Mapping = getFPValueMapping(MRI.getType(Dst).getSizeInBits());
298 return getInstructionMapping(DefaultMappingID, /*Cost=*/1, Mapping,
299 NumOperands);
303 SmallVector<const ValueMapping *, 4> OpdsMapping(NumOperands);
305 switch (Opc) {
306 case TargetOpcode::G_LOAD: {
307 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
308 OpdsMapping[0] = GPRValueMapping;
309 OpdsMapping[1] = GPRValueMapping;
310 // Use FPR64 for s64 loads on rv32.
311 if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
312 assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
313 OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
314 break;
317 // Check if that load feeds fp instructions.
318 // In that case, we want the default mapping to be on FPR
319 // instead of blind map every scalar to GPR.
320 if (anyUseOnlyUseFP(MI.getOperand(0).getReg(), MRI, TRI))
321 // If we have at least one direct use in a FP instruction,
322 // assume this was a floating point load in the IR. If it was
323 // not, we would have had a bitcast before reaching that
324 // instruction.
325 OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
327 break;
329 case TargetOpcode::G_STORE: {
330 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
331 OpdsMapping[0] = GPRValueMapping;
332 OpdsMapping[1] = GPRValueMapping;
333 // Use FPR64 for s64 stores on rv32.
334 if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
335 assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
336 OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
337 break;
340 MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(0).getReg());
341 if (onlyDefinesFP(*DefMI, MRI, TRI))
342 OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
343 break;
345 case TargetOpcode::G_SELECT: {
346 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
348 // Try to minimize the number of copies. If we have more floating point
349 // constrained values than not, then we'll put everything on FPR. Otherwise,
350 // everything has to be on GPR.
351 unsigned NumFP = 0;
353 // Use FPR64 for s64 select on rv32.
354 if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
355 NumFP = 3;
356 } else {
357 // Check if the uses of the result always produce floating point values.
359 // For example:
361 // %z = G_SELECT %cond %x %y
362 // fpr = G_FOO %z ...
363 if (any_of(MRI.use_nodbg_instructions(MI.getOperand(0).getReg()),
364 [&](const MachineInstr &UseMI) {
365 return onlyUsesFP(UseMI, MRI, TRI);
367 ++NumFP;
369 // Check if the defs of the source values always produce floating point
370 // values.
372 // For example:
374 // %x = G_SOMETHING_ALWAYS_FLOAT %a ...
375 // %z = G_SELECT %cond %x %y
377 // Also check whether or not the sources have already been decided to be
378 // FPR. Keep track of this.
380 // This doesn't check the condition, since the condition is always an
381 // integer.
382 for (unsigned Idx = 2; Idx < 4; ++Idx) {
383 Register VReg = MI.getOperand(Idx).getReg();
384 MachineInstr *DefMI = MRI.getVRegDef(VReg);
385 if (getRegBank(VReg, MRI, TRI) == &RISCV::FPRBRegBank ||
386 onlyDefinesFP(*DefMI, MRI, TRI))
387 ++NumFP;
391 // Condition operand is always GPR.
392 OpdsMapping[1] = GPRValueMapping;
394 const ValueMapping *Mapping = GPRValueMapping;
395 if (NumFP >= 2)
396 Mapping = getFPValueMapping(Ty.getSizeInBits());
398 OpdsMapping[0] = OpdsMapping[2] = OpdsMapping[3] = Mapping;
399 break;
401 case TargetOpcode::G_FPTOSI:
402 case TargetOpcode::G_FPTOUI:
403 case RISCV::G_FCLASS: {
404 LLT Ty = MRI.getType(MI.getOperand(1).getReg());
405 OpdsMapping[0] = GPRValueMapping;
406 OpdsMapping[1] = getFPValueMapping(Ty.getSizeInBits());
407 break;
409 case TargetOpcode::G_SITOFP:
410 case TargetOpcode::G_UITOFP: {
411 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
412 OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
413 OpdsMapping[1] = GPRValueMapping;
414 break;
416 case TargetOpcode::G_FCMP: {
417 LLT Ty = MRI.getType(MI.getOperand(2).getReg());
419 unsigned Size = Ty.getSizeInBits();
420 assert((Size == 32 || Size == 64) && "Unsupported size for G_FCMP");
422 OpdsMapping[0] = GPRValueMapping;
423 OpdsMapping[2] = OpdsMapping[3] = getFPValueMapping(Size);
424 break;
426 case TargetOpcode::G_MERGE_VALUES: {
427 // Use FPR64 for s64 merge on rv32.
428 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
429 if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
430 assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
431 OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
432 OpdsMapping[1] = GPRValueMapping;
433 OpdsMapping[2] = GPRValueMapping;
435 break;
437 case TargetOpcode::G_UNMERGE_VALUES: {
438 // Use FPR64 for s64 unmerge on rv32.
439 LLT Ty = MRI.getType(MI.getOperand(2).getReg());
440 if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
441 assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
442 OpdsMapping[0] = GPRValueMapping;
443 OpdsMapping[1] = GPRValueMapping;
444 OpdsMapping[2] = getFPValueMapping(Ty.getSizeInBits());
446 break;
448 default:
449 // By default map all scalars to GPR.
450 for (unsigned Idx = 0; Idx < NumOperands; ++Idx) {
451 auto &MO = MI.getOperand(Idx);
452 if (!MO.isReg() || !MO.getReg())
453 continue;
454 LLT Ty = MRI.getType(MO.getReg());
455 if (!Ty.isValid())
456 continue;
458 if (isPreISelGenericFloatingPointOpcode(Opc))
459 OpdsMapping[Idx] = getFPValueMapping(Ty.getSizeInBits());
460 else
461 OpdsMapping[Idx] = GPRValueMapping;
463 break;
466 return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
467 getOperandsMapping(OpdsMapping), NumOperands);