// [gn build] Port 3986cffe8112
// [llvm-project.git] / llvm / lib / Transforms / Vectorize / SandboxVectorizer / Passes / BottomUpVec.cpp
// blob d44199609838d79042fa6de749287b9008098d0c
1 //===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
10 #include "llvm/ADT/SmallVector.h"
11 #include "llvm/Analysis/TargetTransformInfo.h"
12 #include "llvm/SandboxIR/Function.h"
13 #include "llvm/SandboxIR/Instruction.h"
14 #include "llvm/SandboxIR/Module.h"
15 #include "llvm/SandboxIR/Utils.h"
16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
17 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
18 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
20 namespace llvm {
/// When non-zero, runOnFunction() uses this value as the vector register size
/// in bits instead of TTI's fixed-width vector register bit width.
static cl::opt<unsigned>
    OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
                       cl::desc("Override the vector register size in bits, "
                                "which is otherwise found by querying TTI."));
/// runOnFunction() passes `!AllowNonPow2` as the last argument of
/// SeedBundle::getSlice(), i.e. by default slices are restricted to
/// power-of-2 sizes.
static cl::opt<bool>
    AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
                 cl::desc("Allow non-power-of-2 vectorization."));
30 namespace sandboxir {
32 BottomUpVec::BottomUpVec(StringRef Pipeline)
33 : FunctionPass("bottom-up-vec"),
34 RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}
36 static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
37 unsigned OpIdx) {
38 SmallVector<Value *, 4> Operands;
39 for (Value *BndlV : Bndl) {
40 auto *BndlI = cast<Instruction>(BndlV);
41 Operands.push_back(BndlI->getOperand(OpIdx));
43 return Operands;
46 static BasicBlock::iterator
47 getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
48 // TODO: Use the VecUtils function for getting the bottom instr once it lands.
49 auto *BotI = cast<Instruction>(
50 *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
51 return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
52 }));
53 // If Bndl contains Arguments or Constants, use the beginning of the BB.
54 return std::next(BotI->getIterator());
57 Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
58 ArrayRef<Value *> Operands) {
59 Change = true;
60 assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
61 "Expect Instructions!");
62 auto &Ctx = Bndl[0]->getContext();
64 Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
65 auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
67 BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);
69 auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
70 switch (Opcode) {
71 case Instruction::Opcode::ZExt:
72 case Instruction::Opcode::SExt:
73 case Instruction::Opcode::FPToUI:
74 case Instruction::Opcode::FPToSI:
75 case Instruction::Opcode::FPExt:
76 case Instruction::Opcode::PtrToInt:
77 case Instruction::Opcode::IntToPtr:
78 case Instruction::Opcode::SIToFP:
79 case Instruction::Opcode::UIToFP:
80 case Instruction::Opcode::Trunc:
81 case Instruction::Opcode::FPTrunc:
82 case Instruction::Opcode::BitCast: {
83 assert(Operands.size() == 1u && "Casts are unary!");
84 return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast");
86 case Instruction::Opcode::FCmp:
87 case Instruction::Opcode::ICmp: {
88 auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
89 assert(all_of(drop_begin(Bndl),
90 [Pred](auto *SBV) {
91 return cast<CmpInst>(SBV)->getPredicate() == Pred;
92 }) &&
93 "Expected same predicate across bundle.");
94 return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
95 "VCmp");
97 case Instruction::Opcode::Select: {
98 return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
99 Ctx, "Vec");
101 case Instruction::Opcode::FNeg: {
102 auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
103 auto OpC = UOp0->getOpcode();
104 return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt,
105 Ctx, "Vec");
107 case Instruction::Opcode::Add:
108 case Instruction::Opcode::FAdd:
109 case Instruction::Opcode::Sub:
110 case Instruction::Opcode::FSub:
111 case Instruction::Opcode::Mul:
112 case Instruction::Opcode::FMul:
113 case Instruction::Opcode::UDiv:
114 case Instruction::Opcode::SDiv:
115 case Instruction::Opcode::FDiv:
116 case Instruction::Opcode::URem:
117 case Instruction::Opcode::SRem:
118 case Instruction::Opcode::FRem:
119 case Instruction::Opcode::Shl:
120 case Instruction::Opcode::LShr:
121 case Instruction::Opcode::AShr:
122 case Instruction::Opcode::And:
123 case Instruction::Opcode::Or:
124 case Instruction::Opcode::Xor: {
125 auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
126 auto *LHS = Operands[0];
127 auto *RHS = Operands[1];
128 return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS,
129 BinOp0, WhereIt, Ctx, "Vec");
131 case Instruction::Opcode::Load: {
132 auto *Ld0 = cast<LoadInst>(Bndl[0]);
133 Value *Ptr = Ld0->getPointerOperand();
134 return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL");
136 case Instruction::Opcode::Store: {
137 auto Align = cast<StoreInst>(Bndl[0])->getAlign();
138 Value *Val = Operands[0];
139 Value *Ptr = Operands[1];
140 return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
142 case Instruction::Opcode::Br:
143 case Instruction::Opcode::Ret:
144 case Instruction::Opcode::PHI:
145 case Instruction::Opcode::AddrSpaceCast:
146 case Instruction::Opcode::Call:
147 case Instruction::Opcode::GetElementPtr:
148 llvm_unreachable("Unimplemented");
149 break;
150 default:
151 llvm_unreachable("Unimplemented");
152 break;
154 llvm_unreachable("Missing switch case!");
155 // TODO: Propagate debug info.
158 void BottomUpVec::tryEraseDeadInstrs() {
159 // Visiting the dead instructions bottom-to-top.
160 SmallVector<Instruction *> SortedDeadInstrCandidates(
161 DeadInstrCandidates.begin(), DeadInstrCandidates.end());
162 sort(SortedDeadInstrCandidates,
163 [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
164 for (Instruction *I : reverse(SortedDeadInstrCandidates)) {
165 if (I->hasNUses(0))
166 I->eraseFromParent();
168 DeadInstrCandidates.clear();
171 Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
172 BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);
174 Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
175 unsigned Lanes = VecUtils::getNumLanes(ToPack);
176 Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);
178 // Create a series of pack instructions.
179 Value *LastInsert = PoisonValue::get(VecTy);
181 Context &Ctx = ToPack[0]->getContext();
183 unsigned InsertIdx = 0;
184 for (Value *Elm : ToPack) {
185 // An element can be either scalar or vector. We need to generate different
186 // IR for each case.
187 if (Elm->getType()->isVectorTy()) {
188 unsigned NumElms =
189 cast<FixedVectorType>(Elm->getType())->getNumElements();
190 for (auto ExtrLane : seq<int>(0, NumElms)) {
191 // We generate extract-insert pairs, for each lane in `Elm`.
192 Constant *ExtrLaneC =
193 ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane);
194 // This may return a Constant if Elm is a Constant.
195 auto *ExtrI =
196 ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
197 if (!isa<Constant>(ExtrI))
198 WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
199 Constant *InsertLaneC =
200 ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
201 // This may also return a Constant if ExtrI is a Constant.
202 auto *InsertI = InsertElementInst::create(
203 LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack");
204 if (!isa<Constant>(InsertI)) {
205 LastInsert = InsertI;
206 WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator());
209 } else {
210 Constant *InsertLaneC =
211 ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
212 // This may be folded into a Constant if LastInsert is a Constant. In
213 // that case we only collect the last constant.
214 LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
215 WhereIt, Ctx, "Pack");
216 if (auto *NewI = dyn_cast<Instruction>(LastInsert))
217 WhereIt = std::next(NewI->getIterator());
220 return LastInsert;
223 void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl) {
224 for (Value *V : Bndl)
225 DeadInstrCandidates.insert(cast<Instruction>(V));
226 // Also collect the GEPs of vectorized loads and stores.
227 auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
228 switch (Opcode) {
229 case Instruction::Opcode::Load: {
230 for (Value *V : drop_begin(Bndl))
231 if (auto *Ptr =
232 dyn_cast<Instruction>(cast<LoadInst>(V)->getPointerOperand()))
233 DeadInstrCandidates.insert(Ptr);
234 break;
236 case Instruction::Opcode::Store: {
237 for (Value *V : drop_begin(Bndl))
238 if (auto *Ptr =
239 dyn_cast<Instruction>(cast<StoreInst>(V)->getPointerOperand()))
240 DeadInstrCandidates.insert(Ptr);
241 break;
243 default:
244 break;
248 Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
249 Value *NewVec = nullptr;
250 const auto &LegalityRes = Legality->canVectorize(Bndl);
251 switch (LegalityRes.getSubclassID()) {
252 case LegalityResultID::Widen: {
253 auto *I = cast<Instruction>(Bndl[0]);
254 SmallVector<Value *, 2> VecOperands;
255 switch (I->getOpcode()) {
256 case Instruction::Opcode::Load:
257 // Don't recurse towards the pointer operand.
258 VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
259 break;
260 case Instruction::Opcode::Store: {
261 // Don't recurse towards the pointer operand.
262 auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Depth + 1);
263 VecOperands.push_back(VecOp);
264 VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
265 break;
267 default:
268 // Visit all operands.
269 for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
270 auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Depth + 1);
271 VecOperands.push_back(VecOp);
273 break;
275 NewVec = createVectorInstr(Bndl, VecOperands);
277 // Collect any potentially dead scalar instructions, including the original
278 // scalars and pointer operands of loads/stores.
279 if (NewVec != nullptr)
280 collectPotentiallyDeadInstrs(Bndl);
281 break;
283 case LegalityResultID::Pack: {
284 // If we can't vectorize the seeds then just return.
285 if (Depth == 0)
286 return nullptr;
287 NewVec = createPack(Bndl);
288 break;
291 return NewVec;
/// Attempts to vectorize the seed bundle \p Bndl and then erases the scalar
/// instructions that the attempt left without uses.
/// \returns the pass-wide `Change` flag. This flag is reset only in
/// runOnFunction(), so it reports whether anything in the function has been
/// modified so far, not only by this call.
bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
  // Reset per-attempt state before starting on this bundle.
  DeadInstrCandidates.clear();
  Legality->clear();
  vectorizeRec(Bndl, /*Depth=*/0);
  // Must run after vectorizeRec(), which fills DeadInstrCandidates.
  tryEraseDeadInstrs();
  return Change;
}
302 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
303 Legality = std::make_unique<LegalityAnalysis>(
304 A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
305 F.getContext());
306 Change = false;
307 const auto &DL = F.getParent()->getDataLayout();
308 unsigned VecRegBits =
309 OverrideVecRegBits != 0
310 ? OverrideVecRegBits
311 : A.getTTI()
312 .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
313 .getFixedValue();
315 // TODO: Start from innermost BBs first
316 for (auto &BB : F) {
317 SeedCollector SC(&BB, A.getScalarEvolution());
318 for (SeedBundle &Seeds : SC.getStoreSeeds()) {
319 unsigned ElmBits =
320 Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
321 Seeds[Seeds.getFirstUnusedElementIdx()])),
322 DL);
324 auto DivideBy2 = [](unsigned Num) {
325 auto Floor = VecUtils::getFloorPowerOf2(Num);
326 if (Floor == Num)
327 return Floor / 2;
328 return Floor;
330 // Try to create the largest vector supported by the target. If it fails
331 // reduce the vector size by half.
332 for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
333 Seeds.getNumUnusedBits() / ElmBits);
334 SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
335 if (Seeds.allUsed())
336 break;
337 // Keep trying offsets after FirstUnusedElementIdx, until we vectorize
338 // the slice. This could be quite expensive, so we enforce a limit.
339 for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
340 OE = Seeds.size();
341 Offset + 1 < OE; Offset += 1) {
342 // Seeds are getting used as we vectorize, so skip them.
343 if (Seeds.isUsed(Offset))
344 continue;
345 if (Seeds.allUsed())
346 break;
348 auto SeedSlice =
349 Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
350 if (SeedSlice.empty())
351 continue;
353 assert(SeedSlice.size() >= 2 && "Should have been rejected!");
355 // TODO: If vectorization succeeds, run the RegionPassManager on the
356 // resulting region.
358 // TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
359 SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
360 SeedSlice.end());
361 Change |= tryVectorize(SeedSliceVals);
366 return Change;
369 } // namespace sandboxir
370 } // namespace llvm