[RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (...
[llvm-project.git] / llvm / lib / Transforms / Vectorize / VPlanAnalysis.cpp
blob 97a8a1803bbf5a5d4e8490773da572d50e425c91
1 //===- VPlanAnalysis.cpp - Various Analyses working on VPlan ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "VPlanAnalysis.h"
10 #include "VPlan.h"
11 #include "llvm/ADT/TypeSwitch.h"
13 using namespace llvm;
15 #define DEBUG_TYPE "vplan"
17 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPBlendRecipe *R) {
18 Type *ResTy = inferScalarType(R->getIncomingValue(0));
19 for (unsigned I = 1, E = R->getNumIncomingValues(); I != E; ++I) {
20 VPValue *Inc = R->getIncomingValue(I);
21 assert(inferScalarType(Inc) == ResTy &&
22 "different types inferred for different incoming values");
23 CachedTypes[Inc] = ResTy;
25 return ResTy;
28 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
29 switch (R->getOpcode()) {
30 case Instruction::Select: {
31 Type *ResTy = inferScalarType(R->getOperand(1));
32 VPValue *OtherV = R->getOperand(2);
33 assert(inferScalarType(OtherV) == ResTy &&
34 "different types inferred for different operands");
35 CachedTypes[OtherV] = ResTy;
36 return ResTy;
38 case VPInstruction::FirstOrderRecurrenceSplice: {
39 Type *ResTy = inferScalarType(R->getOperand(0));
40 VPValue *OtherV = R->getOperand(1);
41 assert(inferScalarType(OtherV) == ResTy &&
42 "different types inferred for different operands");
43 CachedTypes[OtherV] = ResTy;
44 return ResTy;
46 default:
47 break;
49 // Type inference not implemented for opcode.
50 LLVM_DEBUG({
51 dbgs() << "LV: Found unhandled opcode for: ";
52 R->getVPSingleValue()->dump();
53 });
54 llvm_unreachable("Unhandled opcode!");
57 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) {
58 unsigned Opcode = R->getOpcode();
59 switch (Opcode) {
60 case Instruction::ICmp:
61 case Instruction::FCmp:
62 return IntegerType::get(Ctx, 1);
63 case Instruction::UDiv:
64 case Instruction::SDiv:
65 case Instruction::SRem:
66 case Instruction::URem:
67 case Instruction::Add:
68 case Instruction::FAdd:
69 case Instruction::Sub:
70 case Instruction::FSub:
71 case Instruction::Mul:
72 case Instruction::FMul:
73 case Instruction::FDiv:
74 case Instruction::FRem:
75 case Instruction::Shl:
76 case Instruction::LShr:
77 case Instruction::AShr:
78 case Instruction::And:
79 case Instruction::Or:
80 case Instruction::Xor: {
81 Type *ResTy = inferScalarType(R->getOperand(0));
82 assert(ResTy == inferScalarType(R->getOperand(1)) &&
83 "types for both operands must match for binary op");
84 CachedTypes[R->getOperand(1)] = ResTy;
85 return ResTy;
87 case Instruction::FNeg:
88 case Instruction::Freeze:
89 return inferScalarType(R->getOperand(0));
90 default:
91 break;
94 // Type inference not implemented for opcode.
95 LLVM_DEBUG({
96 dbgs() << "LV: Found unhandled opcode for: ";
97 R->getVPSingleValue()->dump();
98 });
99 llvm_unreachable("Unhandled opcode!");
102 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenCallRecipe *R) {
103 auto &CI = *cast<CallInst>(R->getUnderlyingInstr());
104 return CI.getType();
107 Type *VPTypeAnalysis::inferScalarTypeForRecipe(
108 const VPWidenMemoryInstructionRecipe *R) {
109 assert(!R->isStore() && "Store recipes should not define any values");
110 return cast<LoadInst>(&R->getIngredient())->getType();
113 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenSelectRecipe *R) {
114 Type *ResTy = inferScalarType(R->getOperand(1));
115 VPValue *OtherV = R->getOperand(2);
116 assert(inferScalarType(OtherV) == ResTy &&
117 "different types inferred for different operands");
118 CachedTypes[OtherV] = ResTy;
119 return ResTy;
122 Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) {
123 switch (R->getUnderlyingInstr()->getOpcode()) {
124 case Instruction::Call: {
125 unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1);
126 return cast<Function>(R->getOperand(CallIdx)->getLiveInIRValue())
127 ->getReturnType();
129 case Instruction::UDiv:
130 case Instruction::SDiv:
131 case Instruction::SRem:
132 case Instruction::URem:
133 case Instruction::Add:
134 case Instruction::FAdd:
135 case Instruction::Sub:
136 case Instruction::FSub:
137 case Instruction::Mul:
138 case Instruction::FMul:
139 case Instruction::FDiv:
140 case Instruction::FRem:
141 case Instruction::Shl:
142 case Instruction::LShr:
143 case Instruction::AShr:
144 case Instruction::And:
145 case Instruction::Or:
146 case Instruction::Xor: {
147 Type *ResTy = inferScalarType(R->getOperand(0));
148 assert(ResTy == inferScalarType(R->getOperand(1)) &&
149 "inferred types for operands of binary op don't match");
150 CachedTypes[R->getOperand(1)] = ResTy;
151 return ResTy;
153 case Instruction::Select: {
154 Type *ResTy = inferScalarType(R->getOperand(1));
155 assert(ResTy == inferScalarType(R->getOperand(2)) &&
156 "inferred types for operands of select op don't match");
157 CachedTypes[R->getOperand(2)] = ResTy;
158 return ResTy;
160 case Instruction::ICmp:
161 case Instruction::FCmp:
162 return IntegerType::get(Ctx, 1);
163 case Instruction::Alloca:
164 case Instruction::BitCast:
165 case Instruction::Trunc:
166 case Instruction::SExt:
167 case Instruction::ZExt:
168 case Instruction::FPExt:
169 case Instruction::FPTrunc:
170 case Instruction::ExtractValue:
171 case Instruction::SIToFP:
172 case Instruction::UIToFP:
173 case Instruction::FPToSI:
174 case Instruction::FPToUI:
175 case Instruction::PtrToInt:
176 case Instruction::IntToPtr:
177 return R->getUnderlyingInstr()->getType();
178 case Instruction::Freeze:
179 case Instruction::FNeg:
180 case Instruction::GetElementPtr:
181 return inferScalarType(R->getOperand(0));
182 case Instruction::Load:
183 return cast<LoadInst>(R->getUnderlyingInstr())->getType();
184 case Instruction::Store:
185 // FIXME: VPReplicateRecipes with store opcodes still define a result
186 // VPValue, so we need to handle them here. Remove the code here once this
187 // is modeled accurately in VPlan.
188 return Type::getVoidTy(Ctx);
189 default:
190 break;
192 // Type inference not implemented for opcode.
193 LLVM_DEBUG({
194 dbgs() << "LV: Found unhandled opcode for: ";
195 R->getVPSingleValue()->dump();
197 llvm_unreachable("Unhandled opcode");
200 Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
201 if (Type *CachedTy = CachedTypes.lookup(V))
202 return CachedTy;
204 if (V->isLiveIn())
205 return V->getLiveInIRValue()->getType();
207 Type *ResultTy =
208 TypeSwitch<const VPRecipeBase *, Type *>(V->getDefiningRecipe())
209 .Case<VPCanonicalIVPHIRecipe, VPFirstOrderRecurrencePHIRecipe,
210 VPReductionPHIRecipe, VPWidenPointerInductionRecipe>(
211 [this](const auto *R) {
212 // Handle header phi recipes, except VPWienIntOrFpInduction
213 // which needs special handling due it being possibly truncated.
214 // TODO: consider inferring/caching type of siblings, e.g.,
215 // backedge value, here and in cases below.
216 return inferScalarType(R->getStartValue());
218 .Case<VPWidenIntOrFpInductionRecipe, VPDerivedIVRecipe>(
219 [](const auto *R) { return R->getScalarType(); })
220 .Case<VPPredInstPHIRecipe, VPWidenPHIRecipe, VPScalarIVStepsRecipe,
221 VPWidenGEPRecipe>([this](const VPRecipeBase *R) {
222 return inferScalarType(R->getOperand(0));
224 .Case<VPBlendRecipe, VPInstruction, VPWidenRecipe, VPReplicateRecipe,
225 VPWidenCallRecipe, VPWidenMemoryInstructionRecipe,
226 VPWidenSelectRecipe>(
227 [this](const auto *R) { return inferScalarTypeForRecipe(R); })
228 .Case<VPInterleaveRecipe>([V](const VPInterleaveRecipe *R) {
229 // TODO: Use info from interleave group.
230 return V->getUnderlyingValue()->getType();
232 .Case<VPWidenCastRecipe>(
233 [](const VPWidenCastRecipe *R) { return R->getResultType(); });
234 assert(ResultTy && "could not infer type for the given VPValue");
235 CachedTypes[V] = ResultTy;
236 return ResultTy;