//===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Module.h"
#include "llvm/SandboxIR/Utils.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
#include <algorithm>
22 static cl::opt
<unsigned>
23 OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden
,
24 cl::desc("Override the vector register size in bits, "
25 "which is otherwise found by querying TTI."));
27 AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden
,
28 cl::desc("Allow non-power-of-2 vectorization."));
32 BottomUpVec::BottomUpVec(StringRef Pipeline
)
33 : FunctionPass("bottom-up-vec"),
34 RPM("rpm", Pipeline
, SandboxVectorizerPassBuilder::createRegionPass
) {}
36 static SmallVector
<Value
*, 4> getOperand(ArrayRef
<Value
*> Bndl
,
38 SmallVector
<Value
*, 4> Operands
;
39 for (Value
*BndlV
: Bndl
) {
40 auto *BndlI
= cast
<Instruction
>(BndlV
);
41 Operands
.push_back(BndlI
->getOperand(OpIdx
));
46 static BasicBlock::iterator
47 getInsertPointAfterInstrs(ArrayRef
<Value
*> Instrs
) {
48 // TODO: Use the VecUtils function for getting the bottom instr once it lands.
49 auto *BotI
= cast
<Instruction
>(
50 *std::max_element(Instrs
.begin(), Instrs
.end(), [](auto *V1
, auto *V2
) {
51 return cast
<Instruction
>(V1
)->comesBefore(cast
<Instruction
>(V2
));
53 // If Bndl contains Arguments or Constants, use the beginning of the BB.
54 return std::next(BotI
->getIterator());
57 Value
*BottomUpVec::createVectorInstr(ArrayRef
<Value
*> Bndl
,
58 ArrayRef
<Value
*> Operands
) {
60 assert(all_of(Bndl
, [](auto *V
) { return isa
<Instruction
>(V
); }) &&
61 "Expect Instructions!");
62 auto &Ctx
= Bndl
[0]->getContext();
64 Type
*ScalarTy
= VecUtils::getElementType(Utils::getExpectedType(Bndl
[0]));
65 auto *VecTy
= VecUtils::getWideType(ScalarTy
, VecUtils::getNumLanes(Bndl
));
67 BasicBlock::iterator WhereIt
= getInsertPointAfterInstrs(Bndl
);
69 auto Opcode
= cast
<Instruction
>(Bndl
[0])->getOpcode();
71 case Instruction::Opcode::ZExt
:
72 case Instruction::Opcode::SExt
:
73 case Instruction::Opcode::FPToUI
:
74 case Instruction::Opcode::FPToSI
:
75 case Instruction::Opcode::FPExt
:
76 case Instruction::Opcode::PtrToInt
:
77 case Instruction::Opcode::IntToPtr
:
78 case Instruction::Opcode::SIToFP
:
79 case Instruction::Opcode::UIToFP
:
80 case Instruction::Opcode::Trunc
:
81 case Instruction::Opcode::FPTrunc
:
82 case Instruction::Opcode::BitCast
: {
83 assert(Operands
.size() == 1u && "Casts are unary!");
84 return CastInst::create(VecTy
, Opcode
, Operands
[0], WhereIt
, Ctx
, "VCast");
86 case Instruction::Opcode::FCmp
:
87 case Instruction::Opcode::ICmp
: {
88 auto Pred
= cast
<CmpInst
>(Bndl
[0])->getPredicate();
89 assert(all_of(drop_begin(Bndl
),
91 return cast
<CmpInst
>(SBV
)->getPredicate() == Pred
;
93 "Expected same predicate across bundle.");
94 return CmpInst::create(Pred
, Operands
[0], Operands
[1], WhereIt
, Ctx
,
97 case Instruction::Opcode::Select
: {
98 return SelectInst::create(Operands
[0], Operands
[1], Operands
[2], WhereIt
,
101 case Instruction::Opcode::FNeg
: {
102 auto *UOp0
= cast
<UnaryOperator
>(Bndl
[0]);
103 auto OpC
= UOp0
->getOpcode();
104 return UnaryOperator::createWithCopiedFlags(OpC
, Operands
[0], UOp0
, WhereIt
,
107 case Instruction::Opcode::Add
:
108 case Instruction::Opcode::FAdd
:
109 case Instruction::Opcode::Sub
:
110 case Instruction::Opcode::FSub
:
111 case Instruction::Opcode::Mul
:
112 case Instruction::Opcode::FMul
:
113 case Instruction::Opcode::UDiv
:
114 case Instruction::Opcode::SDiv
:
115 case Instruction::Opcode::FDiv
:
116 case Instruction::Opcode::URem
:
117 case Instruction::Opcode::SRem
:
118 case Instruction::Opcode::FRem
:
119 case Instruction::Opcode::Shl
:
120 case Instruction::Opcode::LShr
:
121 case Instruction::Opcode::AShr
:
122 case Instruction::Opcode::And
:
123 case Instruction::Opcode::Or
:
124 case Instruction::Opcode::Xor
: {
125 auto *BinOp0
= cast
<BinaryOperator
>(Bndl
[0]);
126 auto *LHS
= Operands
[0];
127 auto *RHS
= Operands
[1];
128 return BinaryOperator::createWithCopiedFlags(BinOp0
->getOpcode(), LHS
, RHS
,
129 BinOp0
, WhereIt
, Ctx
, "Vec");
131 case Instruction::Opcode::Load
: {
132 auto *Ld0
= cast
<LoadInst
>(Bndl
[0]);
133 Value
*Ptr
= Ld0
->getPointerOperand();
134 return LoadInst::create(VecTy
, Ptr
, Ld0
->getAlign(), WhereIt
, Ctx
, "VecL");
136 case Instruction::Opcode::Store
: {
137 auto Align
= cast
<StoreInst
>(Bndl
[0])->getAlign();
138 Value
*Val
= Operands
[0];
139 Value
*Ptr
= Operands
[1];
140 return StoreInst::create(Val
, Ptr
, Align
, WhereIt
, Ctx
);
142 case Instruction::Opcode::Br
:
143 case Instruction::Opcode::Ret
:
144 case Instruction::Opcode::PHI
:
145 case Instruction::Opcode::AddrSpaceCast
:
146 case Instruction::Opcode::Call
:
147 case Instruction::Opcode::GetElementPtr
:
148 llvm_unreachable("Unimplemented");
151 llvm_unreachable("Unimplemented");
154 llvm_unreachable("Missing switch case!");
155 // TODO: Propagate debug info.
158 void BottomUpVec::tryEraseDeadInstrs() {
159 // Visiting the dead instructions bottom-to-top.
160 SmallVector
<Instruction
*> SortedDeadInstrCandidates(
161 DeadInstrCandidates
.begin(), DeadInstrCandidates
.end());
162 sort(SortedDeadInstrCandidates
,
163 [](Instruction
*I1
, Instruction
*I2
) { return I1
->comesBefore(I2
); });
164 for (Instruction
*I
: reverse(SortedDeadInstrCandidates
)) {
166 I
->eraseFromParent();
168 DeadInstrCandidates
.clear();
171 Value
*BottomUpVec::createPack(ArrayRef
<Value
*> ToPack
) {
172 BasicBlock::iterator WhereIt
= getInsertPointAfterInstrs(ToPack
);
174 Type
*ScalarTy
= VecUtils::getCommonScalarType(ToPack
);
175 unsigned Lanes
= VecUtils::getNumLanes(ToPack
);
176 Type
*VecTy
= VecUtils::getWideType(ScalarTy
, Lanes
);
178 // Create a series of pack instructions.
179 Value
*LastInsert
= PoisonValue::get(VecTy
);
181 Context
&Ctx
= ToPack
[0]->getContext();
183 unsigned InsertIdx
= 0;
184 for (Value
*Elm
: ToPack
) {
185 // An element can be either scalar or vector. We need to generate different
187 if (Elm
->getType()->isVectorTy()) {
189 cast
<FixedVectorType
>(Elm
->getType())->getNumElements();
190 for (auto ExtrLane
: seq
<int>(0, NumElms
)) {
191 // We generate extract-insert pairs, for each lane in `Elm`.
192 Constant
*ExtrLaneC
=
193 ConstantInt::getSigned(Type::getInt32Ty(Ctx
), ExtrLane
);
194 // This may return a Constant if Elm is a Constant.
196 ExtractElementInst::create(Elm
, ExtrLaneC
, WhereIt
, Ctx
, "VPack");
197 if (!isa
<Constant
>(ExtrI
))
198 WhereIt
= std::next(cast
<Instruction
>(ExtrI
)->getIterator());
199 Constant
*InsertLaneC
=
200 ConstantInt::getSigned(Type::getInt32Ty(Ctx
), InsertIdx
++);
201 // This may also return a Constant if ExtrI is a Constant.
202 auto *InsertI
= InsertElementInst::create(
203 LastInsert
, ExtrI
, InsertLaneC
, WhereIt
, Ctx
, "VPack");
204 if (!isa
<Constant
>(InsertI
)) {
205 LastInsert
= InsertI
;
206 WhereIt
= std::next(cast
<Instruction
>(LastInsert
)->getIterator());
210 Constant
*InsertLaneC
=
211 ConstantInt::getSigned(Type::getInt32Ty(Ctx
), InsertIdx
++);
212 // This may be folded into a Constant if LastInsert is a Constant. In
213 // that case we only collect the last constant.
214 LastInsert
= InsertElementInst::create(LastInsert
, Elm
, InsertLaneC
,
215 WhereIt
, Ctx
, "Pack");
216 if (auto *NewI
= dyn_cast
<Instruction
>(LastInsert
))
217 WhereIt
= std::next(NewI
->getIterator());
223 void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef
<Value
*> Bndl
) {
224 for (Value
*V
: Bndl
)
225 DeadInstrCandidates
.insert(cast
<Instruction
>(V
));
226 // Also collect the GEPs of vectorized loads and stores.
227 auto Opcode
= cast
<Instruction
>(Bndl
[0])->getOpcode();
229 case Instruction::Opcode::Load
: {
230 for (Value
*V
: drop_begin(Bndl
))
232 dyn_cast
<Instruction
>(cast
<LoadInst
>(V
)->getPointerOperand()))
233 DeadInstrCandidates
.insert(Ptr
);
236 case Instruction::Opcode::Store
: {
237 for (Value
*V
: drop_begin(Bndl
))
239 dyn_cast
<Instruction
>(cast
<StoreInst
>(V
)->getPointerOperand()))
240 DeadInstrCandidates
.insert(Ptr
);
248 Value
*BottomUpVec::vectorizeRec(ArrayRef
<Value
*> Bndl
, unsigned Depth
) {
249 Value
*NewVec
= nullptr;
250 const auto &LegalityRes
= Legality
->canVectorize(Bndl
);
251 switch (LegalityRes
.getSubclassID()) {
252 case LegalityResultID::Widen
: {
253 auto *I
= cast
<Instruction
>(Bndl
[0]);
254 SmallVector
<Value
*, 2> VecOperands
;
255 switch (I
->getOpcode()) {
256 case Instruction::Opcode::Load
:
257 // Don't recurse towards the pointer operand.
258 VecOperands
.push_back(cast
<LoadInst
>(I
)->getPointerOperand());
260 case Instruction::Opcode::Store
: {
261 // Don't recurse towards the pointer operand.
262 auto *VecOp
= vectorizeRec(getOperand(Bndl
, 0), Depth
+ 1);
263 VecOperands
.push_back(VecOp
);
264 VecOperands
.push_back(cast
<StoreInst
>(I
)->getPointerOperand());
268 // Visit all operands.
269 for (auto OpIdx
: seq
<unsigned>(I
->getNumOperands())) {
270 auto *VecOp
= vectorizeRec(getOperand(Bndl
, OpIdx
), Depth
+ 1);
271 VecOperands
.push_back(VecOp
);
275 NewVec
= createVectorInstr(Bndl
, VecOperands
);
277 // Collect any potentially dead scalar instructions, including the original
278 // scalars and pointer operands of loads/stores.
279 if (NewVec
!= nullptr)
280 collectPotentiallyDeadInstrs(Bndl
);
283 case LegalityResultID::Pack
: {
284 // If we can't vectorize the seeds then just return.
287 NewVec
= createPack(Bndl
);
294 bool BottomUpVec::tryVectorize(ArrayRef
<Value
*> Bndl
) {
295 DeadInstrCandidates
.clear();
297 vectorizeRec(Bndl
, /*Depth=*/0);
298 tryEraseDeadInstrs();
302 bool BottomUpVec::runOnFunction(Function
&F
, const Analyses
&A
) {
303 Legality
= std::make_unique
<LegalityAnalysis
>(
304 A
.getAA(), A
.getScalarEvolution(), F
.getParent()->getDataLayout(),
307 const auto &DL
= F
.getParent()->getDataLayout();
308 unsigned VecRegBits
=
309 OverrideVecRegBits
!= 0
312 .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector
)
315 // TODO: Start from innermost BBs first
317 SeedCollector
SC(&BB
, A
.getScalarEvolution());
318 for (SeedBundle
&Seeds
: SC
.getStoreSeeds()) {
320 Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
321 Seeds
[Seeds
.getFirstUnusedElementIdx()])),
324 auto DivideBy2
= [](unsigned Num
) {
325 auto Floor
= VecUtils::getFloorPowerOf2(Num
);
330 // Try to create the largest vector supported by the target. If it fails
331 // reduce the vector size by half.
332 for (unsigned SliceElms
= std::min(VecRegBits
/ ElmBits
,
333 Seeds
.getNumUnusedBits() / ElmBits
);
334 SliceElms
>= 2u; SliceElms
= DivideBy2(SliceElms
)) {
337 // Keep trying offsets after FirstUnusedElementIdx, until we vectorize
338 // the slice. This could be quite expensive, so we enforce a limit.
339 for (unsigned Offset
= Seeds
.getFirstUnusedElementIdx(),
341 Offset
+ 1 < OE
; Offset
+= 1) {
342 // Seeds are getting used as we vectorize, so skip them.
343 if (Seeds
.isUsed(Offset
))
349 Seeds
.getSlice(Offset
, SliceElms
* ElmBits
, !AllowNonPow2
);
350 if (SeedSlice
.empty())
353 assert(SeedSlice
.size() >= 2 && "Should have been rejected!");
355 // TODO: If vectorization succeeds, run the RegionPassManager on the
358 // TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
359 SmallVector
<Value
*> SeedSliceVals(SeedSlice
.begin(),
361 Change
|= tryVectorize(SeedSliceVals
);
369 } // namespace sandboxir