//===- Scalarizer.cpp - Scalarize vector operations -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass converts vector operations into scalar operations, in order
// to expose optimization opportunities on the individual scalar operations.
// It is mainly intended for targets that do not have vector units, but it
// may also be useful for revectorizing code to different vector widths.
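//
// For example, given IR like
//
//   %add = fadd <2 x float> %x, %y
//
// the pass produces (with illustrative operand names)
//
//   %x.i0 = extractelement <2 x float> %x, i32 0
//   %y.i0 = extractelement <2 x float> %y, i32 0
//   %add.i0 = fadd float %x.i0, %y.i0
//   %x.i1 = extractelement <2 x float> %x, i32 1
//   %y.i1 = extractelement <2 x float> %y, i32 1
//   %add.i1 = fadd float %x.i1, %y.i1
//
// and rebuilds a vector result with insertelements only if a vector use
// remains (see ScalarizerVisitor::finish below).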
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/Scalarizer.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstdint>
#include <iterator>
#include <map>
#include <utility>

using namespace llvm;

#define DEBUG_TYPE "scalarizer"
static cl::opt<bool> ScalarizeVariableInsertExtract(
    "scalarize-variable-insert-extract", cl::init(true), cl::Hidden,
    cl::desc("Allow the scalarizer pass to scalarize "
             "insertelement/extractelement with variable index"));

// This is disabled by default because having separate loads and stores
// makes it more likely that the -combiner-alias-analysis limits will be
// reached.
static cl::opt<bool> ScalarizeLoadStore(
    "scalarize-load-store", cl::init(false), cl::Hidden,
    cl::desc("Allow the scalarizer pass to scalarize loads and stores"));
namespace {

// Used to store the scattered form of a vector.
using ValueVector = SmallVector<Value *, 8>;

// Used to map a vector Value to its scattered form. We use std::map
// because we want iterators to persist across insertion and because the
// values are relatively large.
using ScatterMap = std::map<Value *, ValueVector>;

// Lists Instructions that have been replaced with scalar implementations,
// along with a pointer to their scattered forms.
using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>;

// Provides a very limited vector-like interface for lazily accessing one
// component of a scattered vector or vector pointer.
class Scatterer {
public:
  Scatterer() = default;

  // Scatter V into Size components. If new instructions are needed,
  // insert them before BBI in BB. If Cache is nonnull, use it to cache
  // the results.
  Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
            ValueVector *cachePtr = nullptr);

  // Return component I, creating a new Value for it if necessary.
  Value *operator[](unsigned I);

  // Return the number of components.
  unsigned size() const { return Size; }

private:
  BasicBlock *BB;
  BasicBlock::iterator BBI;
  Value *V;
  ValueVector *CachePtr;
  PointerType *PtrTy;
  ValueVector Tmp;
  unsigned Size;
};
// FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
// called Name that compares X and Y in the same way as FCI.
struct FCmpSplitter {
  FCmpSplitter(FCmpInst &fci) : FCI(fci) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
  }

  FCmpInst &FCI;
};
// ICmpSplitter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
// called Name that compares X and Y in the same way as ICI.
struct ICmpSplitter {
  ICmpSplitter(ICmpInst &ici) : ICI(ici) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
  }

  ICmpInst &ICI;
};
// UnarySplitter(UO)(Builder, X, Name) uses Builder to create
// a unary operator like UO called Name with operand X.
struct UnarySplitter {
  UnarySplitter(UnaryOperator &uo) : UO(uo) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op, const Twine &Name) const {
    return Builder.CreateUnOp(UO.getOpcode(), Op, Name);
  }

  UnaryOperator &UO;
};
// BinarySplitter(BO)(Builder, X, Y, Name) uses Builder to create
// a binary operator like BO called Name with operands X and Y.
struct BinarySplitter {
  BinarySplitter(BinaryOperator &bo) : BO(bo) {}

  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
                    const Twine &Name) const {
    return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
  }

  BinaryOperator &BO;
};
// Information about a load or store that we're scalarizing.
struct VectorLayout {
  VectorLayout() = default;

  // Return the alignment of element I.
  Align getElemAlign(unsigned I) {
    return commonAlignment(VecAlign, I * ElemSize);
  }

  // The type of the vector.
  VectorType *VecTy = nullptr;

  // The type of each element.
  Type *ElemTy = nullptr;

  // The alignment of the vector.
  Align VecAlign;

  // The size of each element.
  uint64_t ElemSize = 0;
};
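// As a worked example of getElemAlign: for a <4 x i32> with VecAlign = 16
// and ElemSize = 4, element 0 keeps alignment 16, element 1 (byte offset 4)
// gets commonAlignment(16, 4) = 4, and element 2 (byte offset 8) gets 8.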
class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
public:
  ScalarizerVisitor(unsigned ParallelLoopAccessMDKind, DominatorTree *DT)
      : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind), DT(DT) {
  }

  bool visit(Function &F);

  // InstVisitor methods. They return true if the instruction was scalarized,
  // false if nothing changed.
  bool visitInstruction(Instruction &I) { return false; }
  bool visitSelectInst(SelectInst &SI);
  bool visitICmpInst(ICmpInst &ICI);
  bool visitFCmpInst(FCmpInst &FCI);
  bool visitUnaryOperator(UnaryOperator &UO);
  bool visitBinaryOperator(BinaryOperator &BO);
  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
  bool visitCastInst(CastInst &CI);
  bool visitBitCastInst(BitCastInst &BCI);
  bool visitInsertElementInst(InsertElementInst &IEI);
  bool visitExtractElementInst(ExtractElementInst &EEI);
  bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
  bool visitPHINode(PHINode &PHI);
  bool visitLoadInst(LoadInst &LI);
  bool visitStoreInst(StoreInst &SI);
  bool visitCallInst(CallInst &ICI);

private:
  Scatterer scatter(Instruction *Point, Value *V);
  void gather(Instruction *Op, const ValueVector &CV);
  bool canTransferMetadata(unsigned Kind);
  void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
  Optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment,
                                         const DataLayout &DL);

  template<typename T> bool splitUnary(Instruction &, const T &);
  template<typename T> bool splitBinary(Instruction &, const T &);

  bool splitCall(CallInst &CI);

  bool finish();

  ScatterMap Scattered;
  GatherList Gathered;

  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;

  unsigned ParallelLoopAccessMDKind;

  DominatorTree *DT;
};
class ScalarizerLegacyPass : public FunctionPass {
public:
  static char ID;

  ScalarizerLegacyPass() : FunctionPass(ID) {
    initializeScalarizerLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace
char ScalarizerLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer",
                      "Scalarize vector operations", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
                    "Scalarize vector operations", false, false)
Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
                     ValueVector *cachePtr)
    : BB(bb), BBI(bbi), V(v), CachePtr(cachePtr) {
  Type *Ty = V->getType();
  PtrTy = dyn_cast<PointerType>(Ty);
  if (PtrTy)
    Ty = PtrTy->getElementType();
  Size = cast<FixedVectorType>(Ty)->getNumElements();
  if (!CachePtr)
    Tmp.resize(Size, nullptr);
  else if (CachePtr->empty())
    CachePtr->resize(Size, nullptr);
  else
    assert(Size == CachePtr->size() && "Inconsistent vector sizes");
}
// Return component I, creating a new Value for it if necessary.
Value *Scatterer::operator[](unsigned I) {
  ValueVector &CV = (CachePtr ? *CachePtr : Tmp);
  // Try to reuse a previous value.
  if (CV[I])
    return CV[I];
  IRBuilder<> Builder(BB, BBI);
  if (PtrTy) {
    Type *ElTy = cast<VectorType>(PtrTy->getElementType())->getElementType();
    if (!CV[0]) {
      Type *NewPtrTy = PointerType::get(ElTy, PtrTy->getAddressSpace());
      CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0");
    }
    if (I != 0)
      CV[I] = Builder.CreateConstGEP1_32(ElTy, CV[0], I,
                                         V->getName() + ".i" + Twine(I));
  } else {
    // Search through a chain of InsertElementInsts looking for element I.
    // Record other elements in the cache. The new V is still suitable
    // for all uncached indices.
    while (true) {
      InsertElementInst *Insert = dyn_cast<InsertElementInst>(V);
      if (!Insert)
        break;
      ConstantInt *Idx = dyn_cast<ConstantInt>(Insert->getOperand(2));
      if (!Idx)
        break;
      unsigned J = Idx->getZExtValue();
      V = Insert->getOperand(0);
      if (I == J) {
        CV[J] = Insert->getOperand(1);
        return CV[J];
      }
      if (!CV[J]) {
        // Only cache the first entry we find for each index we're not actively
        // searching for. This prevents us from going too far up the chain and
        // caching incorrect entries.
        CV[J] = Insert->getOperand(1);
      }
    }
    if (!CV[I])
      CV[I] = Builder.CreateExtractElement(V, Builder.getInt32(I),
                                           V->getName() + ".i" + Twine(I));
  }
  return CV[I];
}
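// For example, given the chain (illustrative values)
//   %v1 = insertelement <2 x float> poison, float %a, i32 0
//   %v2 = insertelement <2 x float> %v1, float %b, i32 1
// a Scatterer over %v2 resolves component 1 to %b immediately, and while
// searching for component 0 it caches %b on the way to returning %a, so
// no extractelement instructions need to be emitted at all.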
bool ScalarizerLegacyPass::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  Module &M = *F.getParent();
  unsigned ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT);
  return Impl.visit(F);
}

FunctionPass *llvm::createScalarizerPass() {
  return new ScalarizerLegacyPass();
}
bool ScalarizerVisitor::visit(Function &F) {
  assert(Gathered.empty() && Scattered.empty());

  // To ensure we replace gathered components correctly, we need to do an
  // ordered traversal of the basic blocks in the function.
  ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
  for (BasicBlock *BB : RPOT) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      Instruction *I = &*II;
      bool Done = InstVisitor::visit(I);
      ++II;
      if (Done && I->getType()->isVoidTy())
        I->eraseFromParent();
    }
  }
  return finish();
}
// Return a scattered form of V that can be accessed by Point. V must be a
// vector or a pointer to a vector.
Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) {
  if (Argument *VArg = dyn_cast<Argument>(V)) {
    // Put the scattered form of arguments in the entry block,
    // so that it can be used everywhere.
    Function *F = VArg->getParent();
    BasicBlock *BB = &F->getEntryBlock();
    return Scatterer(BB, BB->begin(), V, &Scattered[V]);
  }
  if (Instruction *VOp = dyn_cast<Instruction>(V)) {
    // When scalarizing PHI nodes we might try to examine/rewrite InsertElement
    // nodes in predecessors. If those predecessors are unreachable from entry,
    // then the IR in those blocks could have unexpected properties resulting in
    // infinite loops in Scatterer::operator[]. By simply treating values
    // originating from instructions in unreachable blocks as undef we do not
    // need to analyse them further.
    if (!DT->isReachableFromEntry(VOp->getParent()))
      return Scatterer(Point->getParent(), Point->getIterator(),
                       UndefValue::get(V->getType()));
    // Put the scattered form of an instruction directly after the
    // instruction itself.
    BasicBlock *BB = VOp->getParent();
    return Scatterer(BB, std::next(BasicBlock::iterator(VOp)),
                     V, &Scattered[V]);
  }
  // In the fallback case, just put the scattered form of V before Point and
  // keep the result local to Point.
  return Scatterer(Point->getParent(), Point->getIterator(), V);
}
// Replace Op with the gathered form of the components in CV. Defer the
// deletion of Op and creation of the gathered form to the end of the pass,
// so that we can avoid creating the gathered form if all uses of Op are
// replaced with uses of CV.
void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) {
  transferMetadataAndIRFlags(Op, CV);

  // If we already have a scattered form of Op (created from ExtractElements
  // of Op itself), replace them with the new form.
  ValueVector &SV = Scattered[Op];
  if (!SV.empty()) {
    for (unsigned I = 0, E = SV.size(); I != E; ++I) {
      Value *V = SV[I];
      if (V == nullptr || SV[I] == CV[I])
        continue;

      Instruction *Old = cast<Instruction>(V);
      if (isa<Instruction>(CV[I]))
        CV[I]->takeName(Old);
      Old->replaceAllUsesWith(CV[I]);
      PotentiallyDeadInstrs.emplace_back(Old);
    }
  }
  SV = CV;
  Gathered.push_back(GatherList::value_type(Op, &SV));
}
// Return true if it is safe to transfer the given metadata tag from
// vector to scalar instructions.
bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
  return (Tag == LLVMContext::MD_tbaa
          || Tag == LLVMContext::MD_fpmath
          || Tag == LLVMContext::MD_tbaa_struct
          || Tag == LLVMContext::MD_invariant_load
          || Tag == LLVMContext::MD_alias_scope
          || Tag == LLVMContext::MD_noalias
          || Tag == ParallelLoopAccessMDKind
          || Tag == LLVMContext::MD_access_group);
}
// Transfer metadata from Op to the instructions in CV if it is known
// to be safe to do so.
void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
                                                   const ValueVector &CV) {
  SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
  Op->getAllMetadataOtherThanDebugLoc(MDs);
  for (unsigned I = 0, E = CV.size(); I != E; ++I) {
    if (Instruction *New = dyn_cast<Instruction>(CV[I])) {
      for (const auto &MD : MDs)
        if (canTransferMetadata(MD.first))
          New->setMetadata(MD.first, MD.second);
      New->copyIRFlags(Op);
      if (Op->getDebugLoc() && !New->getDebugLoc())
        New->setDebugLoc(Op->getDebugLoc());
    }
  }
}
// Try to fill in a VectorLayout for Ty, returning None on failure.
// Alignment is the alignment of the vector.
Optional<VectorLayout>
ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
                                   const DataLayout &DL) {
  VectorLayout Layout;
  // Make sure we're dealing with a vector.
  Layout.VecTy = dyn_cast<VectorType>(Ty);
  if (!Layout.VecTy)
    return None;
  // Check that we're dealing with full-byte elements.
  Layout.ElemTy = Layout.VecTy->getElementType();
  if (!DL.typeSizeEqualsStoreSize(Layout.ElemTy))
    return None;
  Layout.VecAlign = Alignment;
  Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy);
  return Layout;
}
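// For instance, a <4 x i32> with Alignment 16 yields ElemTy = i32 and
// ElemSize = 4, whereas a <4 x i1> is rejected (None) because i1's bit
// size does not equal its in-memory store size.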
// Scalarize one-operand instruction I, using Split(Builder, X, Name)
// to create an instruction like I with operand X and name Name.
template<typename Splitter>
bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
  VectorType *VT = dyn_cast<VectorType>(I.getType());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&I);
  Scatterer Op = scatter(&I, I.getOperand(0));
  assert(Op.size() == NumElems && "Mismatched unary operation");
  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned Elem = 0; Elem < NumElems; ++Elem)
    Res[Elem] = Split(Builder, Op[Elem], I.getName() + ".i" + Twine(Elem));
  gather(&I, Res);
  return true;
}
// Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
// to create an instruction like I with operands X and Y and name Name.
template<typename Splitter>
bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
  VectorType *VT = dyn_cast<VectorType>(I.getType());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&I);
  Scatterer VOp0 = scatter(&I, I.getOperand(0));
  Scatterer VOp1 = scatter(&I, I.getOperand(1));
  assert(VOp0.size() == NumElems && "Mismatched binary operation");
  assert(VOp1.size() == NumElems && "Mismatched binary operation");
  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
    Value *Op0 = VOp0[Elem];
    Value *Op1 = VOp1[Elem];
    Res[Elem] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Elem));
  }
  gather(&I, Res);
  return true;
}
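// For example, splitBinary(I, BinarySplitter(BO)) turns
//   %r = mul <2 x i32> %a, %b
// into
//   %r.i0 = mul i32 %a.i0, %b.i0
//   %r.i1 = mul i32 %a.i1, %b.i1
// (illustrative names; the operand components come from the Scatterers).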
static bool isTriviallyScalarizable(Intrinsic::ID ID) {
  return isTriviallyVectorizable(ID);
}

// All of the current scalarizable intrinsics only have one mangled type.
static Function *getScalarIntrinsicDeclaration(Module *M,
                                               Intrinsic::ID ID,
                                               ArrayRef<Type *> Tys) {
  return Intrinsic::getDeclaration(M, ID, Tys);
}
/// If CI is a call to a vector-typed intrinsic function, split it into one
/// scalar call per element, when the intrinsic allows it.
bool ScalarizerVisitor::splitCall(CallInst &CI) {
  VectorType *VT = dyn_cast<VectorType>(CI.getType());
  if (!VT)
    return false;

  Function *F = CI.getCalledFunction();
  if (!F)
    return false;

  Intrinsic::ID ID = F->getIntrinsicID();
  if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID))
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  unsigned NumArgs = CI.getNumArgOperands();

  ValueVector ScalarOperands(NumArgs);
  SmallVector<Scatterer, 8> Scattered(NumArgs);

  Scattered.resize(NumArgs);

  SmallVector<llvm::Type *, 3> Tys;
  Tys.push_back(VT->getScalarType());

  // Assumes that any vector-typed operand has the same number of elements
  // as the return vector type, which is true for all current intrinsics.
  for (unsigned I = 0; I != NumArgs; ++I) {
    Value *OpI = CI.getOperand(I);
    if (OpI->getType()->isVectorTy()) {
      Scattered[I] = scatter(&CI, OpI);
      assert(Scattered[I].size() == NumElems && "mismatched call operands");
    } else {
      ScalarOperands[I] = OpI;
      if (hasVectorInstrinsicOverloadedScalarOpd(ID, I))
        Tys.push_back(OpI->getType());
    }
  }

  ValueVector Res(NumElems);
  ValueVector ScalarCallOps(NumArgs);

  Function *NewIntrin = getScalarIntrinsicDeclaration(F->getParent(), ID, Tys);
  IRBuilder<> Builder(&CI);

  // Perform actual scalarization, taking care to preserve any scalar operands.
  for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
    ScalarCallOps.clear();

    for (unsigned J = 0; J != NumArgs; ++J) {
      if (hasVectorInstrinsicScalarOpd(ID, J))
        ScalarCallOps.push_back(ScalarOperands[J]);
      else
        ScalarCallOps.push_back(Scattered[J][Elem]);
    }

    Res[Elem] = Builder.CreateCall(NewIntrin, ScalarCallOps,
                                   CI.getName() + ".i" + Twine(Elem));
  }

  gather(&CI, Res);
  return true;
}
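// For example, a trivially scalarizable intrinsic call such as
//   %r = call <2 x float> @llvm.fabs.v2f32(<2 x float> %x)
// becomes
//   %r.i0 = call float @llvm.fabs.f32(float %x.i0)
//   %r.i1 = call float @llvm.fabs.f32(float %x.i1)
// Scalar operands (e.g. the i32 exponent of @llvm.powi) are passed through
// unchanged to every scalar call.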
bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
  VectorType *VT = dyn_cast<VectorType>(SI.getType());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&SI);
  Scatterer VOp1 = scatter(&SI, SI.getOperand(1));
  Scatterer VOp2 = scatter(&SI, SI.getOperand(2));
  assert(VOp1.size() == NumElems && "Mismatched select");
  assert(VOp2.size() == NumElems && "Mismatched select");
  ValueVector Res;
  Res.resize(NumElems);

  if (SI.getOperand(0)->getType()->isVectorTy()) {
    Scatterer VOp0 = scatter(&SI, SI.getOperand(0));
    assert(VOp0.size() == NumElems && "Mismatched select");
    for (unsigned I = 0; I < NumElems; ++I) {
      Value *Op0 = VOp0[I];
      Value *Op1 = VOp1[I];
      Value *Op2 = VOp2[I];
      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
                                    SI.getName() + ".i" + Twine(I));
    }
  } else {
    Value *Op0 = SI.getOperand(0);
    for (unsigned I = 0; I < NumElems; ++I) {
      Value *Op1 = VOp1[I];
      Value *Op2 = VOp2[I];
      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,
                                    SI.getName() + ".i" + Twine(I));
    }
  }
  gather(&SI, Res);
  return true;
}
bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) {
  return splitBinary(ICI, ICmpSplitter(ICI));
}

bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) {
  return splitBinary(FCI, FCmpSplitter(FCI));
}

bool ScalarizerVisitor::visitUnaryOperator(UnaryOperator &UO) {
  return splitUnary(UO, UnarySplitter(UO));
}

bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) {
  return splitBinary(BO, BinarySplitter(BO));
}
bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
  VectorType *VT = dyn_cast<VectorType>(GEPI.getType());
  if (!VT)
    return false;

  IRBuilder<> Builder(&GEPI);
  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  unsigned NumIndices = GEPI.getNumIndices();

  // The base pointer might be scalar even if it's a vector GEP. In those cases,
  // splat the pointer into a vector value, and scatter that vector.
  Value *Op0 = GEPI.getOperand(0);
  if (!Op0->getType()->isVectorTy())
    Op0 = Builder.CreateVectorSplat(NumElems, Op0);
  Scatterer Base = scatter(&GEPI, Op0);

  SmallVector<Scatterer, 8> Ops;
  Ops.resize(NumIndices);
  for (unsigned I = 0; I < NumIndices; ++I) {
    Value *Op = GEPI.getOperand(I + 1);

    // The indices might be scalars even if it's a vector GEP. In those cases,
    // splat the scalar into a vector value, and scatter that vector.
    if (!Op->getType()->isVectorTy())
      Op = Builder.CreateVectorSplat(NumElems, Op);

    Ops[I] = scatter(&GEPI, Op);
  }

  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I) {
    SmallVector<Value *, 8> Indices;
    Indices.resize(NumIndices);
    for (unsigned J = 0; J < NumIndices; ++J)
      Indices[J] = Ops[J][I];
    Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), Base[I], Indices,
                               GEPI.getName() + ".i" + Twine(I));
    if (GEPI.isInBounds())
      if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I]))
        NewGEPI->setIsInBounds();
  }
  gather(&GEPI, Res);
  return true;
}
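// For example, a GEP with a scalar base and a vector index such as
//   %p = getelementptr inbounds i32, i32* %base, <2 x i64> %idx
// conceptually becomes two scalar GEPs,
//   %p.i0 = getelementptr inbounds i32, i32* %base, i64 %idx.i0
//   %p.i1 = getelementptr inbounds i32, i32* %base, i64 %idx.i1
// with the splatted base's components materialized by the Scatterer.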
bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
  VectorType *VT = dyn_cast<VectorType>(CI.getDestTy());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&CI);
  Scatterer Op0 = scatter(&CI, CI.getOperand(0));
  assert(Op0.size() == NumElems && "Mismatched cast");
  ValueVector Res;
  Res.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I)
    Res[I] = Builder.CreateCast(CI.getOpcode(), Op0[I], VT->getElementType(),
                                CI.getName() + ".i" + Twine(I));
  gather(&CI, Res);
  return true;
}
bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
  VectorType *DstVT = dyn_cast<VectorType>(BCI.getDestTy());
  VectorType *SrcVT = dyn_cast<VectorType>(BCI.getSrcTy());
  if (!DstVT || !SrcVT)
    return false;

  unsigned DstNumElems = cast<FixedVectorType>(DstVT)->getNumElements();
  unsigned SrcNumElems = cast<FixedVectorType>(SrcVT)->getNumElements();
  IRBuilder<> Builder(&BCI);
  Scatterer Op0 = scatter(&BCI, BCI.getOperand(0));
  ValueVector Res;
  Res.resize(DstNumElems);

  if (DstNumElems == SrcNumElems) {
    for (unsigned I = 0; I < DstNumElems; ++I)
      Res[I] = Builder.CreateBitCast(Op0[I], DstVT->getElementType(),
                                     BCI.getName() + ".i" + Twine(I));
  } else if (DstNumElems > SrcNumElems) {
    // <M x t1> -> <N*M x t2>. Convert each t1 to <N x t2> and copy the
    // individual elements to the destination.
    unsigned FanOut = DstNumElems / SrcNumElems;
    auto *MidTy = FixedVectorType::get(DstVT->getElementType(), FanOut);
    unsigned ResI = 0;
    for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) {
      Value *V = Op0[Op0I];
      Instruction *VI;
      // Look through any existing bitcasts before converting to <N x t2>.
      // In the best case, the resulting conversion might be a no-op.
      while ((VI = dyn_cast<Instruction>(V)) &&
             VI->getOpcode() == Instruction::BitCast)
        V = VI->getOperand(0);
      V = Builder.CreateBitCast(V, MidTy, V->getName() + ".cast");
      Scatterer Mid = scatter(&BCI, V);
      for (unsigned MidI = 0; MidI < FanOut; ++MidI)
        Res[ResI++] = Mid[MidI];
    }
  } else {
    // <N*M x t1> -> <M x t2>. Convert each group of <N x t1> into a t2.
    unsigned FanIn = SrcNumElems / DstNumElems;
    auto *MidTy = FixedVectorType::get(SrcVT->getElementType(), FanIn);
    unsigned Op0I = 0;
    for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) {
      Value *V = PoisonValue::get(MidTy);
      for (unsigned MidI = 0; MidI < FanIn; ++MidI)
        V = Builder.CreateInsertElement(V, Op0[Op0I++], Builder.getInt32(MidI),
                                        BCI.getName() + ".i" + Twine(ResI)
                                        + ".upto" + Twine(MidI));
      Res[ResI] = Builder.CreateBitCast(V, DstVT->getElementType(),
                                        BCI.getName() + ".i" + Twine(ResI));
    }
  }
  gather(&BCI, Res);
  return true;
}
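// For example, bitcasting <2 x i64> to <4 x i32> has FanOut = 2: each i64
// component is bitcast to a <2 x i32> "mid" vector whose two elements are
// copied to the result. The reverse cast has FanIn = 2: pairs of i32
// components are packed into a <2 x i32> with insertelement and then
// bitcast to a single i64.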
bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
  VectorType *VT = dyn_cast<VectorType>(IEI.getType());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&IEI);
  Scatterer Op0 = scatter(&IEI, IEI.getOperand(0));
  Value *NewElt = IEI.getOperand(1);
  Value *InsIdx = IEI.getOperand(2);

  ValueVector Res;
  Res.resize(NumElems);

  if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) {
    for (unsigned I = 0; I < NumElems; ++I)
      Res[I] = CI->getValue().getZExtValue() == I ? NewElt : Op0[I];
  } else {
    if (!ScalarizeVariableInsertExtract)
      return false;

    for (unsigned I = 0; I < NumElems; ++I) {
      Value *ShouldReplace =
          Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I),
                               InsIdx->getName() + ".is." + Twine(I));
      Value *OldElt = Op0[I];
      Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt,
                                    IEI.getName() + ".i" + Twine(I));
    }
  }

  gather(&IEI, Res);
  return true;
}
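// With a variable index (and scalarize-variable-insert-extract enabled),
//   %r = insertelement <2 x float> %v, float %f, i32 %idx
// becomes a compare-and-select per lane:
//   %idx.is.0 = icmp eq i32 %idx, 0
//   %r.i0 = select i1 %idx.is.0, float %f, float %v.i0
//   %idx.is.1 = icmp eq i32 %idx, 1
//   %r.i1 = select i1 %idx.is.1, float %f, float %v.i1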
bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
  VectorType *VT = dyn_cast<VectorType>(EEI.getOperand(0)->getType());
  if (!VT)
    return false;

  unsigned NumSrcElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&EEI);
  Scatterer Op0 = scatter(&EEI, EEI.getOperand(0));
  Value *ExtIdx = EEI.getOperand(1);

  if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) {
    Value *Res = Op0[CI->getValue().getZExtValue()];
    gather(&EEI, {Res});
    return true;
  }

  if (!ScalarizeVariableInsertExtract)
    return false;

  Value *Res = UndefValue::get(VT->getElementType());
  for (unsigned I = 0; I < NumSrcElems; ++I) {
    Value *ShouldExtract =
        Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I),
                             ExtIdx->getName() + ".is." + Twine(I));
    Value *Elt = Op0[I];
    Res = Builder.CreateSelect(ShouldExtract, Elt, Res,
                               EEI.getName() + ".upto" + Twine(I));
  }
  gather(&EEI, {Res});
  return true;
}
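// With a variable index, the extract becomes a linear chain of selects,
// e.g. for <2 x float>:
//   %idx.is.0 = icmp eq i32 %idx, 0
//   %r.upto0 = select i1 %idx.is.0, float %v.i0, float undef
//   %idx.is.1 = icmp eq i32 %idx, 1
//   %r.upto1 = select i1 %idx.is.1, float %v.i1, float %r.upto0
// and the final select is the gathered result (illustrative names).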
bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
  VectorType *VT = dyn_cast<VectorType>(SVI.getType());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  Scatterer Op0 = scatter(&SVI, SVI.getOperand(0));
  Scatterer Op1 = scatter(&SVI, SVI.getOperand(1));
  ValueVector Res;
  Res.resize(NumElems);

  for (unsigned I = 0; I < NumElems; ++I) {
    int Selector = SVI.getMaskValue(I);
    if (Selector < 0)
      Res[I] = UndefValue::get(VT->getElementType());
    else if (unsigned(Selector) < Op0.size())
      Res[I] = Op0[Selector];
    else
      Res[I] = Op1[Selector - Op0.size()];
  }
  gather(&SVI, Res);
  return true;
}
bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
  VectorType *VT = dyn_cast<VectorType>(PHI.getType());
  if (!VT)
    return false;

  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();
  IRBuilder<> Builder(&PHI);
  ValueVector Res;
  Res.resize(NumElems);

  unsigned NumOps = PHI.getNumOperands();
  for (unsigned I = 0; I < NumElems; ++I)
    Res[I] = Builder.CreatePHI(VT->getElementType(), NumOps,
                               PHI.getName() + ".i" + Twine(I));

  for (unsigned I = 0; I < NumOps; ++I) {
    Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I));
    BasicBlock *IncomingBlock = PHI.getIncomingBlock(I);
    for (unsigned J = 0; J < NumElems; ++J)
      cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock);
  }
  gather(&PHI, Res);
  return true;
}
&LI
) {
871 if (!ScalarizeLoadStore
)
876 Optional
<VectorLayout
> Layout
= getVectorLayout(
877 LI
.getType(), LI
.getAlign(), LI
.getModule()->getDataLayout());
881 unsigned NumElems
= cast
<FixedVectorType
>(Layout
->VecTy
)->getNumElements();
882 IRBuilder
<> Builder(&LI
);
883 Scatterer Ptr
= scatter(&LI
, LI
.getPointerOperand());
885 Res
.resize(NumElems
);
887 for (unsigned I
= 0; I
< NumElems
; ++I
)
888 Res
[I
] = Builder
.CreateAlignedLoad(Layout
->VecTy
->getElementType(), Ptr
[I
],
889 Align(Layout
->getElemAlign(I
)),
890 LI
.getName() + ".i" + Twine(I
));
bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
  if (!ScalarizeLoadStore)
    return false;
  if (!SI.isSimple())
    return false;

  Value *FullValue = SI.getValueOperand();
  Optional<VectorLayout> Layout = getVectorLayout(
      FullValue->getType(), SI.getAlign(), SI.getModule()->getDataLayout());
  if (!Layout)
    return false;

  unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();
  IRBuilder<> Builder(&SI);
  Scatterer VPtr = scatter(&SI, SI.getPointerOperand());
  Scatterer VVal = scatter(&SI, FullValue);

  ValueVector Stores;
  Stores.resize(NumElems);
  for (unsigned I = 0; I < NumElems; ++I) {
    Value *Val = VVal[I];
    Value *Ptr = VPtr[I];
    Stores[I] = Builder.CreateAlignedStore(Val, Ptr, Layout->getElemAlign(I));
  }
  transferMetadataAndIRFlags(&SI, Stores);
  return true;
}
bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
  return splitCall(CI);
}
// Delete the instructions that we scalarized. If a full vector result
// is still needed, recreate it using InsertElements.
bool ScalarizerVisitor::finish() {
  // The presence of data in Gathered or Scattered indicates changes
  // made to the Function.
  if (Gathered.empty() && Scattered.empty())
    return false;
  for (const auto &GMI : Gathered) {
    Instruction *Op = GMI.first;
    ValueVector &CV = *GMI.second;
    if (!Op->use_empty()) {
      // The value is still needed, so recreate it using a series of
      // InsertElements.
      Value *Res = PoisonValue::get(Op->getType());
      if (auto *Ty = dyn_cast<VectorType>(Op->getType())) {
        BasicBlock *BB = Op->getParent();
        unsigned Count = cast<FixedVectorType>(Ty)->getNumElements();
        IRBuilder<> Builder(Op);
        if (isa<PHINode>(Op))
          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
        for (unsigned I = 0; I < Count; ++I)
          Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I),
                                            Op->getName() + ".upto" + Twine(I));
        Res->takeName(Op);
      } else {
        assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
        Res = CV[0];
        if (Op == Res)
          continue;
      }
      Op->replaceAllUsesWith(Res);
    }
    PotentiallyDeadInstrs.emplace_back(Op);
  }
  Gathered.clear();
  Scattered.clear();

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs);

  return true;
}
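// For example, if %add was scalarized into %add.i0 and %add.i1 but a vector
// use survives, finish() emits (with illustrative names)
//   %add.upto0 = insertelement <2 x float> poison, float %add.i0, i32 0
//   %add.upto1 = insertelement <2 x float> %add.upto0, float %add.i1, i32 1
// and the final insertelement takes over the name %add.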
PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) {
  Module &M = *F.getParent();
  unsigned ParallelLoopAccessMDKind =
      M.getContext().getMDKindID("llvm.mem.parallel_loop_access");
  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  ScalarizerVisitor Impl(ParallelLoopAccessMDKind, DT);
  bool Changed = Impl.visit(F);
  PreservedAnalyses PA;
  PA.preserve<DominatorTreeAnalysis>();
  return Changed ? PA : PreservedAnalyses::all();
}