1 //===- HexagonLoopIdiomRecognition.cpp ------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "HexagonLoopIdiomRecognition.h"
10 #include "llvm/ADT/APInt.h"
11 #include "llvm/ADT/DenseMap.h"
12 #include "llvm/ADT/SetVector.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallSet.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/ADT/Triple.h"
18 #include "llvm/Analysis/AliasAnalysis.h"
19 #include "llvm/Analysis/InstructionSimplify.h"
20 #include "llvm/Analysis/LoopAnalysisManager.h"
21 #include "llvm/Analysis/LoopInfo.h"
22 #include "llvm/Analysis/LoopPass.h"
23 #include "llvm/Analysis/MemoryLocation.h"
24 #include "llvm/Analysis/ScalarEvolution.h"
25 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
26 #include "llvm/Analysis/TargetLibraryInfo.h"
27 #include "llvm/Analysis/ValueTracking.h"
28 #include "llvm/IR/Attributes.h"
29 #include "llvm/IR/BasicBlock.h"
30 #include "llvm/IR/Constant.h"
31 #include "llvm/IR/Constants.h"
32 #include "llvm/IR/DataLayout.h"
33 #include "llvm/IR/DebugLoc.h"
34 #include "llvm/IR/DerivedTypes.h"
35 #include "llvm/IR/Dominators.h"
36 #include "llvm/IR/Function.h"
37 #include "llvm/IR/IRBuilder.h"
38 #include "llvm/IR/InstrTypes.h"
39 #include "llvm/IR/Instruction.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/IntrinsicInst.h"
42 #include "llvm/IR/Intrinsics.h"
43 #include "llvm/IR/IntrinsicsHexagon.h"
44 #include "llvm/IR/Module.h"
45 #include "llvm/IR/PassManager.h"
46 #include "llvm/IR/PatternMatch.h"
47 #include "llvm/IR/Type.h"
48 #include "llvm/IR/User.h"
49 #include "llvm/IR/Value.h"
50 #include "llvm/InitializePasses.h"
51 #include "llvm/Pass.h"
52 #include "llvm/Support/Casting.h"
53 #include "llvm/Support/CommandLine.h"
54 #include "llvm/Support/Compiler.h"
55 #include "llvm/Support/Debug.h"
56 #include "llvm/Support/ErrorHandling.h"
57 #include "llvm/Support/KnownBits.h"
58 #include "llvm/Support/raw_ostream.h"
59 #include "llvm/Transforms/Scalar.h"
60 #include "llvm/Transforms/Utils.h"
61 #include "llvm/Transforms/Utils/Local.h"
62 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
76 #define DEBUG_TYPE "hexagon-lir"
80 static cl::opt
<bool> DisableMemcpyIdiom("disable-memcpy-idiom",
81 cl::Hidden
, cl::init(false),
82 cl::desc("Disable generation of memcpy in loop idiom recognition"));
84 static cl::opt
<bool> DisableMemmoveIdiom("disable-memmove-idiom",
85 cl::Hidden
, cl::init(false),
86 cl::desc("Disable generation of memmove in loop idiom recognition"));
88 static cl::opt
<unsigned> RuntimeMemSizeThreshold("runtime-mem-idiom-threshold",
89 cl::Hidden
, cl::init(0), cl::desc("Threshold (in bytes) for the runtime "
90 "check guarding the memmove."));
92 static cl::opt
<unsigned> CompileTimeMemSizeThreshold(
93 "compile-time-mem-idiom-threshold", cl::Hidden
, cl::init(64),
94 cl::desc("Threshold (in bytes) to perform the transformation, if the "
95 "runtime loop count (mem transfer size) is known at compile-time."));
97 static cl::opt
<bool> OnlyNonNestedMemmove("only-nonnested-memmove-idiom",
98 cl::Hidden
, cl::init(true),
99 cl::desc("Only enable generating memmove in non-nested loops"));
101 static cl::opt
<bool> HexagonVolatileMemcpy(
102 "disable-hexagon-volatile-memcpy", cl::Hidden
, cl::init(false),
103 cl::desc("Enable Hexagon-specific memcpy for volatile destination."));
105 static cl::opt
<unsigned> SimplifyLimit("hlir-simplify-limit", cl::init(10000),
106 cl::Hidden
, cl::desc("Maximum number of simplification steps in HLIR"));
// Symbol name used by the Hexagon volatile-memcpy handling (see the
// HexagonVolatileMemcpy flag above).
static const char *HexagonVolatileMemcpyName =
    "hexagon_memcpy_forward_vp4cp4n2";
114 void initializeHexagonLoopIdiomRecognizeLegacyPassPass(PassRegistry
&);
115 Pass
*createHexagonLoopIdiomPass();
117 } // end namespace llvm
121 class HexagonLoopIdiomRecognize
{
123 explicit HexagonLoopIdiomRecognize(AliasAnalysis
*AA
, DominatorTree
*DT
,
124 LoopInfo
*LF
, const TargetLibraryInfo
*TLI
,
126 : AA(AA
), DT(DT
), LF(LF
), TLI(TLI
), SE(SE
) {}
131 int getSCEVStride(const SCEVAddRecExpr
*StoreEv
);
132 bool isLegalStore(Loop
*CurLoop
, StoreInst
*SI
);
133 void collectStores(Loop
*CurLoop
, BasicBlock
*BB
,
134 SmallVectorImpl
<StoreInst
*> &Stores
);
135 bool processCopyingStore(Loop
*CurLoop
, StoreInst
*SI
, const SCEV
*BECount
);
136 bool coverLoop(Loop
*L
, SmallVectorImpl
<Instruction
*> &Insts
) const;
137 bool runOnLoopBlock(Loop
*CurLoop
, BasicBlock
*BB
, const SCEV
*BECount
,
138 SmallVectorImpl
<BasicBlock
*> &ExitBlocks
);
139 bool runOnCountableLoop(Loop
*L
);
142 const DataLayout
*DL
;
145 const TargetLibraryInfo
*TLI
;
147 bool HasMemcpy
, HasMemmove
;
150 class HexagonLoopIdiomRecognizeLegacyPass
: public LoopPass
{
154 explicit HexagonLoopIdiomRecognizeLegacyPass() : LoopPass(ID
) {
155 initializeHexagonLoopIdiomRecognizeLegacyPassPass(
156 *PassRegistry::getPassRegistry());
159 StringRef
getPassName() const override
{
160 return "Recognize Hexagon-specific loop idioms";
163 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
164 AU
.addRequired
<LoopInfoWrapperPass
>();
165 AU
.addRequiredID(LoopSimplifyID
);
166 AU
.addRequiredID(LCSSAID
);
167 AU
.addRequired
<AAResultsWrapperPass
>();
168 AU
.addRequired
<ScalarEvolutionWrapperPass
>();
169 AU
.addRequired
<DominatorTreeWrapperPass
>();
170 AU
.addRequired
<TargetLibraryInfoWrapperPass
>();
171 AU
.addPreserved
<TargetLibraryInfoWrapperPass
>();
174 bool runOnLoop(Loop
*L
, LPPassManager
&LPM
) override
;
179 using FuncType
= std::function
<Value
*(Instruction
*, LLVMContext
&)>;
180 Rule(StringRef N
, FuncType F
) : Name(N
), Fn(F
) {}
181 StringRef Name
; // For debugging.
185 void addRule(StringRef N
, const Rule::FuncType
&F
) {
186 Rules
.push_back(Rule(N
, F
));
190 struct WorkListType
{
191 WorkListType() = default;
193 void push_back(Value
*V
) {
194 // Do not push back duplicates.
201 Value
*pop_front_val() {
202 Value
*V
= Q
.front();
208 bool empty() const { return Q
.empty(); }
211 std::deque
<Value
*> Q
;
215 using ValueSetType
= std::set
<Value
*>;
217 std::vector
<Rule
> Rules
;
221 using ValueMapType
= DenseMap
<Value
*, Value
*>;
224 ValueSetType Used
; // The set of all cloned values used by Root.
225 ValueSetType Clones
; // The set of all cloned values.
228 Context(Instruction
*Exp
)
229 : Ctx(Exp
->getParent()->getParent()->getContext()) {
233 ~Context() { cleanup(); }
235 void print(raw_ostream
&OS
, const Value
*V
) const;
236 Value
*materialize(BasicBlock
*B
, BasicBlock::iterator At
);
239 friend struct Simplifier
;
241 void initialize(Instruction
*Exp
);
244 template <typename FuncT
> void traverse(Value
*V
, FuncT F
);
245 void record(Value
*V
);
247 void unuse(Value
*V
);
249 bool equal(const Instruction
*I
, const Instruction
*J
) const;
250 Value
*find(Value
*Tree
, Value
*Sub
) const;
251 Value
*subst(Value
*Tree
, Value
*OldV
, Value
*NewV
);
252 void replace(Value
*OldV
, Value
*NewV
);
253 void link(Instruction
*I
, BasicBlock
*B
, BasicBlock::iterator At
);
256 Value
*simplify(Context
&C
);
260 PE(const Simplifier::Context
&c
, Value
*v
= nullptr) : C(c
), V(v
) {}
262 const Simplifier::Context
&C
;
267 raw_ostream
&operator<<(raw_ostream
&OS
, const PE
&P
) {
268 P
.C
.print(OS
, P
.V
? P
.V
: P
.C
.Root
);
272 } // end anonymous namespace
274 char HexagonLoopIdiomRecognizeLegacyPass::ID
= 0;
276 INITIALIZE_PASS_BEGIN(HexagonLoopIdiomRecognizeLegacyPass
, "hexagon-loop-idiom",
277 "Recognize Hexagon-specific loop idioms", false, false)
278 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass
)
279 INITIALIZE_PASS_DEPENDENCY(LoopSimplify
)
280 INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass
)
281 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass
)
282 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass
)
283 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass
)
284 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass
)
285 INITIALIZE_PASS_END(HexagonLoopIdiomRecognizeLegacyPass
, "hexagon-loop-idiom",
286 "Recognize Hexagon-specific loop idioms", false, false)
288 template <typename FuncT
>
289 void Simplifier::Context::traverse(Value
*V
, FuncT F
) {
294 Instruction
*U
= dyn_cast
<Instruction
>(Q
.pop_front_val());
295 if (!U
|| U
->getParent())
299 for (Value
*Op
: U
->operands())
304 void Simplifier::Context::print(raw_ostream
&OS
, const Value
*V
) const {
305 const auto *U
= dyn_cast
<const Instruction
>(V
);
307 OS
<< V
<< '(' << *V
<< ')';
311 if (U
->getParent()) {
313 U
->printAsOperand(OS
, true);
318 unsigned N
= U
->getNumOperands();
321 OS
<< U
->getOpcodeName();
322 for (const Value
*Op
: U
->operands()) {
330 void Simplifier::Context::initialize(Instruction
*Exp
) {
331 // Perform a deep clone of the expression, set Root to the root
332 // of the clone, and build a map from the cloned values to the
335 BasicBlock
*Block
= Exp
->getParent();
340 Value
*V
= Q
.pop_front_val();
341 if (M
.find(V
) != M
.end())
343 if (Instruction
*U
= dyn_cast
<Instruction
>(V
)) {
344 if (isa
<PHINode
>(U
) || U
->getParent() != Block
)
346 for (Value
*Op
: U
->operands())
348 M
.insert({U
, U
->clone()});
352 for (std::pair
<Value
*,Value
*> P
: M
) {
353 Instruction
*U
= cast
<Instruction
>(P
.second
);
354 for (unsigned i
= 0, n
= U
->getNumOperands(); i
!= n
; ++i
) {
355 auto F
= M
.find(U
->getOperand(i
));
357 U
->setOperand(i
, F
->second
);
361 auto R
= M
.find(Exp
);
362 assert(R
!= M
.end());
369 void Simplifier::Context::record(Value
*V
) {
370 auto Record
= [this](Instruction
*U
) -> bool {
377 void Simplifier::Context::use(Value
*V
) {
378 auto Use
= [this](Instruction
*U
) -> bool {
385 void Simplifier::Context::unuse(Value
*V
) {
386 if (!isa
<Instruction
>(V
) || cast
<Instruction
>(V
)->getParent() != nullptr)
389 auto Unuse
= [this](Instruction
*U
) -> bool {
398 Value
*Simplifier::Context::subst(Value
*Tree
, Value
*OldV
, Value
*NewV
) {
407 Instruction
*U
= dyn_cast
<Instruction
>(Q
.pop_front_val());
408 // If U is not an instruction, or it's not a clone, skip it.
409 if (!U
|| U
->getParent())
411 for (unsigned i
= 0, n
= U
->getNumOperands(); i
!= n
; ++i
) {
412 Value
*Op
= U
->getOperand(i
);
414 U
->setOperand(i
, NewV
);
424 void Simplifier::Context::replace(Value
*OldV
, Value
*NewV
) {
431 // NewV may be a complex tree that has just been created by one of the
432 // transformation rules. We need to make sure that it is commoned with
433 // the existing Root to the maximum extent possible.
434 // Identify all subtrees of NewV (including NewV itself) that have
435 // equivalent counterparts in Root, and replace those subtrees with
436 // these counterparts.
440 Value
*V
= Q
.pop_front_val();
441 Instruction
*U
= dyn_cast
<Instruction
>(V
);
442 if (!U
|| U
->getParent())
444 if (Value
*DupV
= find(Root
, V
)) {
446 NewV
= subst(NewV
, V
, DupV
);
448 for (Value
*Op
: U
->operands())
453 // Now, simply replace OldV with NewV in Root.
454 Root
= subst(Root
, OldV
, NewV
);
458 void Simplifier::Context::cleanup() {
459 for (Value
*V
: Clones
) {
460 Instruction
*U
= cast
<Instruction
>(V
);
462 U
->dropAllReferences();
465 for (Value
*V
: Clones
) {
466 Instruction
*U
= cast
<Instruction
>(V
);
472 bool Simplifier::Context::equal(const Instruction
*I
,
473 const Instruction
*J
) const {
476 if (!I
->isSameOperationAs(J
))
479 return I
->isIdenticalTo(J
);
481 for (unsigned i
= 0, n
= I
->getNumOperands(); i
!= n
; ++i
) {
482 Value
*OpI
= I
->getOperand(i
), *OpJ
= J
->getOperand(i
);
485 auto *InI
= dyn_cast
<const Instruction
>(OpI
);
486 auto *InJ
= dyn_cast
<const Instruction
>(OpJ
);
488 if (!equal(InI
, InJ
))
490 } else if (InI
!= InJ
|| !InI
)
496 Value
*Simplifier::Context::find(Value
*Tree
, Value
*Sub
) const {
497 Instruction
*SubI
= dyn_cast
<Instruction
>(Sub
);
502 Value
*V
= Q
.pop_front_val();
505 Instruction
*U
= dyn_cast
<Instruction
>(V
);
506 if (!U
|| U
->getParent())
508 if (SubI
&& equal(SubI
, U
))
510 assert(!isa
<PHINode
>(U
));
511 for (Value
*Op
: U
->operands())
517 void Simplifier::Context::link(Instruction
*I
, BasicBlock
*B
,
518 BasicBlock::iterator At
) {
522 for (Value
*Op
: I
->operands()) {
523 if (Instruction
*OpI
= dyn_cast
<Instruction
>(Op
))
527 B
->getInstList().insert(At
, I
);
530 Value
*Simplifier::Context::materialize(BasicBlock
*B
,
531 BasicBlock::iterator At
) {
532 if (Instruction
*RootI
= dyn_cast
<Instruction
>(Root
))
537 Value
*Simplifier::simplify(Context
&C
) {
541 const unsigned Limit
= SimplifyLimit
;
544 if (Count
++ >= Limit
)
546 Instruction
*U
= dyn_cast
<Instruction
>(Q
.pop_front_val());
547 if (!U
|| U
->getParent() || !C
.Used
.count(U
))
549 bool Changed
= false;
550 for (Rule
&R
: Rules
) {
551 Value
*W
= R
.Fn(U
, C
.Ctx
);
561 for (Value
*Op
: U
->operands())
565 return Count
< Limit
? C
.Root
: nullptr;
568 //===----------------------------------------------------------------------===//
570 // Implementation of PolynomialMultiplyRecognize
572 //===----------------------------------------------------------------------===//
576 class PolynomialMultiplyRecognize
{
578 explicit PolynomialMultiplyRecognize(Loop
*loop
, const DataLayout
&dl
,
579 const DominatorTree
&dt
, const TargetLibraryInfo
&tli
,
581 : CurLoop(loop
), DL(dl
), DT(dt
), TLI(tli
), SE(se
) {}
586 using ValueSeq
= SetVector
<Value
*>;
588 IntegerType
*getPmpyType() const {
589 LLVMContext
&Ctx
= CurLoop
->getHeader()->getParent()->getContext();
590 return IntegerType::get(Ctx
, 32);
593 bool isPromotableTo(Value
*V
, IntegerType
*Ty
);
594 void promoteTo(Instruction
*In
, IntegerType
*DestTy
, BasicBlock
*LoopB
);
595 bool promoteTypes(BasicBlock
*LoopB
, BasicBlock
*ExitB
);
597 Value
*getCountIV(BasicBlock
*BB
);
598 bool findCycle(Value
*Out
, Value
*In
, ValueSeq
&Cycle
);
599 void classifyCycle(Instruction
*DivI
, ValueSeq
&Cycle
, ValueSeq
&Early
,
601 bool classifyInst(Instruction
*UseI
, ValueSeq
&Early
, ValueSeq
&Late
);
602 bool commutesWithShift(Instruction
*I
);
603 bool highBitsAreZero(Value
*V
, unsigned IterCount
);
604 bool keepsHighBitsZero(Value
*V
, unsigned IterCount
);
605 bool isOperandShifted(Instruction
*I
, Value
*Op
);
606 bool convertShiftsToLeft(BasicBlock
*LoopB
, BasicBlock
*ExitB
,
608 void cleanupLoopBody(BasicBlock
*LoopB
);
610 struct ParsedValues
{
611 ParsedValues() = default;
618 Instruction
*Res
= nullptr;
619 unsigned IterCount
= 0;
624 bool matchLeftShift(SelectInst
*SelI
, Value
*CIV
, ParsedValues
&PV
);
625 bool matchRightShift(SelectInst
*SelI
, ParsedValues
&PV
);
626 bool scanSelect(SelectInst
*SI
, BasicBlock
*LoopB
, BasicBlock
*PrehB
,
627 Value
*CIV
, ParsedValues
&PV
, bool PreScan
);
628 unsigned getInverseMxN(unsigned QP
);
629 Value
*generate(BasicBlock::iterator At
, ParsedValues
&PV
);
631 void setupPreSimplifier(Simplifier
&S
);
632 void setupPostSimplifier(Simplifier
&S
);
635 const DataLayout
&DL
;
636 const DominatorTree
&DT
;
637 const TargetLibraryInfo
&TLI
;
641 } // end anonymous namespace
643 Value
*PolynomialMultiplyRecognize::getCountIV(BasicBlock
*BB
) {
644 pred_iterator PI
= pred_begin(BB
), PE
= pred_end(BB
);
645 if (std::distance(PI
, PE
) != 2)
647 BasicBlock
*PB
= (*PI
== BB
) ? *std::next(PI
) : *PI
;
649 for (auto I
= BB
->begin(), E
= BB
->end(); I
!= E
&& isa
<PHINode
>(I
); ++I
) {
650 auto *PN
= cast
<PHINode
>(I
);
651 Value
*InitV
= PN
->getIncomingValueForBlock(PB
);
652 if (!isa
<ConstantInt
>(InitV
) || !cast
<ConstantInt
>(InitV
)->isZero())
654 Value
*IterV
= PN
->getIncomingValueForBlock(BB
);
655 auto *BO
= dyn_cast
<BinaryOperator
>(IterV
);
658 if (BO
->getOpcode() != Instruction::Add
)
660 Value
*IncV
= nullptr;
661 if (BO
->getOperand(0) == PN
)
662 IncV
= BO
->getOperand(1);
663 else if (BO
->getOperand(1) == PN
)
664 IncV
= BO
->getOperand(0);
668 if (auto *T
= dyn_cast
<ConstantInt
>(IncV
))
669 if (T
->getZExtValue() == 1)
675 static void replaceAllUsesOfWithIn(Value
*I
, Value
*J
, BasicBlock
*BB
) {
676 for (auto UI
= I
->user_begin(), UE
= I
->user_end(); UI
!= UE
;) {
677 Use
&TheUse
= UI
.getUse();
679 if (auto *II
= dyn_cast
<Instruction
>(TheUse
.getUser()))
680 if (BB
== II
->getParent())
681 II
->replaceUsesOfWith(I
, J
);
685 bool PolynomialMultiplyRecognize::matchLeftShift(SelectInst
*SelI
,
686 Value
*CIV
, ParsedValues
&PV
) {
687 // Match the following:
688 // select (X & (1 << i)) != 0 ? R ^ (Q << i) : R
689 // select (X & (1 << i)) == 0 ? R : R ^ (Q << i)
690 // The condition may also check for equality with the masked value, i.e
691 // select (X & (1 << i)) == (1 << i) ? R ^ (Q << i) : R
692 // select (X & (1 << i)) != (1 << i) ? R : R ^ (Q << i);
694 Value
*CondV
= SelI
->getCondition();
695 Value
*TrueV
= SelI
->getTrueValue();
696 Value
*FalseV
= SelI
->getFalseValue();
698 using namespace PatternMatch
;
700 CmpInst::Predicate P
;
701 Value
*A
= nullptr, *B
= nullptr, *C
= nullptr;
703 if (!match(CondV
, m_ICmp(P
, m_And(m_Value(A
), m_Value(B
)), m_Value(C
))) &&
704 !match(CondV
, m_ICmp(P
, m_Value(C
), m_And(m_Value(A
), m_Value(B
)))))
706 if (P
!= CmpInst::ICMP_EQ
&& P
!= CmpInst::ICMP_NE
)
708 // Matched: select (A & B) == C ? ... : ...
709 // select (A & B) != C ? ... : ...
711 Value
*X
= nullptr, *Sh1
= nullptr;
712 // Check (A & B) for (X & (1 << i)):
713 if (match(A
, m_Shl(m_One(), m_Specific(CIV
)))) {
716 } else if (match(B
, m_Shl(m_One(), m_Specific(CIV
)))) {
720 // TODO: Could also check for an induction variable containing single
721 // bit shifted left by 1 in each iteration.
727 // Check C against the possible values for comparison: 0 and (1 << i):
728 if (match(C
, m_Zero()))
729 TrueIfZero
= (P
== CmpInst::ICMP_EQ
);
731 TrueIfZero
= (P
== CmpInst::ICMP_NE
);
736 // select (X & (1 << i)) ? ... : ...
737 // including variations of the check against zero/non-zero value.
739 Value
*ShouldSameV
= nullptr, *ShouldXoredV
= nullptr;
742 ShouldXoredV
= FalseV
;
744 ShouldSameV
= FalseV
;
745 ShouldXoredV
= TrueV
;
748 Value
*Q
= nullptr, *R
= nullptr, *Y
= nullptr, *Z
= nullptr;
750 if (match(ShouldXoredV
, m_Xor(m_Value(Y
), m_Value(Z
)))) {
751 // Matched: select +++ ? ... : Y ^ Z
752 // select +++ ? Y ^ Z : ...
753 // where +++ denotes previously checked matches.
754 if (ShouldSameV
== Y
)
756 else if (ShouldSameV
== Z
)
761 // Matched: select +++ ? R : R ^ T
762 // select +++ ? R ^ T : R
763 // depending on TrueIfZero.
765 } else if (match(ShouldSameV
, m_Zero())) {
766 // Matched: select +++ ? 0 : ...
767 // select +++ ? ... : 0
768 if (!SelI
->hasOneUse())
771 // Matched: select +++ ? 0 : T
772 // select +++ ? T : 0
774 Value
*U
= *SelI
->user_begin();
775 if (!match(U
, m_Xor(m_Specific(SelI
), m_Value(R
))) &&
776 !match(U
, m_Xor(m_Value(R
), m_Specific(SelI
))))
778 // Matched: xor (select +++ ? 0 : T), R
779 // xor (select +++ ? T : 0), R
783 // The xor input value T is isolated into its own match so that it could
784 // be checked against an induction variable containing a shifted bit
786 // For now, check against (Q << i).
787 if (!match(T
, m_Shl(m_Value(Q
), m_Specific(CIV
))) &&
788 !match(T
, m_Shl(m_ZExt(m_Value(Q
)), m_ZExt(m_Specific(CIV
)))))
790 // Matched: select +++ ? R : R ^ (Q << i)
791 // select +++ ? R ^ (Q << i) : R
800 bool PolynomialMultiplyRecognize::matchRightShift(SelectInst
*SelI
,
802 // Match the following:
803 // select (X & 1) != 0 ? (R >> 1) ^ Q : (R >> 1)
804 // select (X & 1) == 0 ? (R >> 1) : (R >> 1) ^ Q
805 // The condition may also check for equality with the masked value, i.e
806 // select (X & 1) == 1 ? (R >> 1) ^ Q : (R >> 1)
807 // select (X & 1) != 1 ? (R >> 1) : (R >> 1) ^ Q
809 Value
*CondV
= SelI
->getCondition();
810 Value
*TrueV
= SelI
->getTrueValue();
811 Value
*FalseV
= SelI
->getFalseValue();
813 using namespace PatternMatch
;
816 CmpInst::Predicate P
;
819 if (match(CondV
, m_ICmp(P
, m_Value(C
), m_Zero())) ||
820 match(CondV
, m_ICmp(P
, m_Zero(), m_Value(C
)))) {
821 if (P
!= CmpInst::ICMP_EQ
&& P
!= CmpInst::ICMP_NE
)
823 // Matched: select C == 0 ? ... : ...
824 // select C != 0 ? ... : ...
825 TrueIfZero
= (P
== CmpInst::ICMP_EQ
);
826 } else if (match(CondV
, m_ICmp(P
, m_Value(C
), m_One())) ||
827 match(CondV
, m_ICmp(P
, m_One(), m_Value(C
)))) {
828 if (P
!= CmpInst::ICMP_EQ
&& P
!= CmpInst::ICMP_NE
)
830 // Matched: select C == 1 ? ... : ...
831 // select C != 1 ? ... : ...
832 TrueIfZero
= (P
== CmpInst::ICMP_NE
);
837 if (!match(C
, m_And(m_Value(X
), m_One())) &&
838 !match(C
, m_And(m_One(), m_Value(X
))))
840 // Matched: select (X & 1) == +++ ? ... : ...
841 // select (X & 1) != +++ ? ... : ...
843 Value
*R
= nullptr, *Q
= nullptr;
845 // The select's condition is true if the tested bit is 0.
846 // TrueV must be the shift, FalseV must be the xor.
847 if (!match(TrueV
, m_LShr(m_Value(R
), m_One())))
849 // Matched: select +++ ? (R >> 1) : ...
850 if (!match(FalseV
, m_Xor(m_Specific(TrueV
), m_Value(Q
))) &&
851 !match(FalseV
, m_Xor(m_Value(Q
), m_Specific(TrueV
))))
853 // Matched: select +++ ? (R >> 1) : (R >> 1) ^ Q
856 // The select's condition is true if the tested bit is 1.
857 // TrueV must be the xor, FalseV must be the shift.
858 if (!match(FalseV
, m_LShr(m_Value(R
), m_One())))
860 // Matched: select +++ ? ... : (R >> 1)
861 if (!match(TrueV
, m_Xor(m_Specific(FalseV
), m_Value(Q
))) &&
862 !match(TrueV
, m_Xor(m_Value(Q
), m_Specific(FalseV
))))
864 // Matched: select +++ ? (R >> 1) ^ Q : (R >> 1)
875 bool PolynomialMultiplyRecognize::scanSelect(SelectInst
*SelI
,
876 BasicBlock
*LoopB
, BasicBlock
*PrehB
, Value
*CIV
, ParsedValues
&PV
,
878 using namespace PatternMatch
;
880 // The basic pattern for R = P.Q is:
883 // if (P & (1 << i)) ; test-bit(P, i)
886 // Similarly, the basic pattern for R = (P/Q).Q - P
892 // There exist idioms, where instead of Q being shifted left, P is shifted
893 // right. This produces a result that is shifted right by 32 bits (the
894 // non-shifted result is 64-bit).
896 // For R = P.Q, this would be:
900 // R' = (R >> 1) ^ Q ; R is cycled through the loop, so it must
901 // else ; be shifted by 1, not i.
904 // And for the inverse:
912 // The left-shifting idioms share the same pattern:
913 // select (X & (1 << i)) ? R ^ (Q << i) : R
914 // Similarly for right-shifting idioms:
915 // select (X & 1) ? (R >> 1) ^ Q
917 if (matchLeftShift(SelI
, CIV
, PV
)) {
918 // If this is a pre-scan, getting this far is sufficient.
922 // Need to make sure that the SelI goes back into R.
923 auto *RPhi
= dyn_cast
<PHINode
>(PV
.R
);
926 if (SelI
!= RPhi
->getIncomingValueForBlock(LoopB
))
930 // If X is loop invariant, it must be the input polynomial, and the
931 // idiom is the basic polynomial multiply.
932 if (CurLoop
->isLoopInvariant(PV
.X
)) {
936 // X is not loop invariant. If X == R, this is the inverse pmpy.
937 // Otherwise, check for an xor with an invariant value. If the
938 // variable argument to the xor is R, then this is still a valid
942 Value
*Var
= nullptr, *Inv
= nullptr, *X1
= nullptr, *X2
= nullptr;
943 if (!match(PV
.X
, m_Xor(m_Value(X1
), m_Value(X2
))))
945 auto *I1
= dyn_cast
<Instruction
>(X1
);
946 auto *I2
= dyn_cast
<Instruction
>(X2
);
947 if (!I1
|| I1
->getParent() != LoopB
) {
950 } else if (!I2
|| I2
->getParent() != LoopB
) {
959 // The input polynomial P still needs to be determined. It will be
960 // the entry value of R.
961 Value
*EntryP
= RPhi
->getIncomingValueForBlock(PrehB
);
968 if (matchRightShift(SelI
, PV
)) {
969 // If this is an inverse pattern, the Q polynomial must be known at
971 if (PV
.Inv
&& !isa
<ConstantInt
>(PV
.Q
))
975 // There is no exact matching of right-shift pmpy.
982 bool PolynomialMultiplyRecognize::isPromotableTo(Value
*Val
,
983 IntegerType
*DestTy
) {
984 IntegerType
*T
= dyn_cast
<IntegerType
>(Val
->getType());
985 if (!T
|| T
->getBitWidth() > DestTy
->getBitWidth())
987 if (T
->getBitWidth() == DestTy
->getBitWidth())
989 // Non-instructions are promotable. The reason why an instruction may not
990 // be promotable is that it may produce a different result if its operands
991 // and the result are promoted, for example, it may produce more non-zero
992 // bits. While it would still be possible to represent the proper result
993 // in a wider type, it may require adding additional instructions (which
994 // we don't want to do).
995 Instruction
*In
= dyn_cast
<Instruction
>(Val
);
998 // The bitwidth of the source type is smaller than the destination.
999 // Check if the individual operation can be promoted.
1000 switch (In
->getOpcode()) {
1001 case Instruction::PHI
:
1002 case Instruction::ZExt
:
1003 case Instruction::And
:
1004 case Instruction::Or
:
1005 case Instruction::Xor
:
1006 case Instruction::LShr
: // Shift right is ok.
1007 case Instruction::Select
:
1008 case Instruction::Trunc
:
1010 case Instruction::ICmp
:
1011 if (CmpInst
*CI
= cast
<CmpInst
>(In
))
1012 return CI
->isEquality() || CI
->isUnsigned();
1013 llvm_unreachable("Cast failed unexpectedly");
1014 case Instruction::Add
:
1015 return In
->hasNoSignedWrap() && In
->hasNoUnsignedWrap();
1020 void PolynomialMultiplyRecognize::promoteTo(Instruction
*In
,
1021 IntegerType
*DestTy
, BasicBlock
*LoopB
) {
1022 Type
*OrigTy
= In
->getType();
1023 assert(!OrigTy
->isVoidTy() && "Invalid instruction to promote");
1025 // Leave boolean values alone.
1026 if (!In
->getType()->isIntegerTy(1))
1027 In
->mutateType(DestTy
);
1028 unsigned DestBW
= DestTy
->getBitWidth();
1031 if (PHINode
*P
= dyn_cast
<PHINode
>(In
)) {
1032 unsigned N
= P
->getNumIncomingValues();
1033 for (unsigned i
= 0; i
!= N
; ++i
) {
1034 BasicBlock
*InB
= P
->getIncomingBlock(i
);
1037 Value
*InV
= P
->getIncomingValue(i
);
1038 IntegerType
*Ty
= cast
<IntegerType
>(InV
->getType());
1039 // Do not promote values in PHI nodes of type i1.
1040 if (Ty
!= P
->getType()) {
1041 // If the value type does not match the PHI type, the PHI type
1042 // must have been promoted.
1043 assert(Ty
->getBitWidth() < DestBW
);
1044 InV
= IRBuilder
<>(InB
->getTerminator()).CreateZExt(InV
, DestTy
);
1045 P
->setIncomingValue(i
, InV
);
1048 } else if (ZExtInst
*Z
= dyn_cast
<ZExtInst
>(In
)) {
1049 Value
*Op
= Z
->getOperand(0);
1050 if (Op
->getType() == Z
->getType())
1051 Z
->replaceAllUsesWith(Op
);
1052 Z
->eraseFromParent();
1055 if (TruncInst
*T
= dyn_cast
<TruncInst
>(In
)) {
1056 IntegerType
*TruncTy
= cast
<IntegerType
>(OrigTy
);
1057 Value
*Mask
= ConstantInt::get(DestTy
, (1u << TruncTy
->getBitWidth()) - 1);
1058 Value
*And
= IRBuilder
<>(In
).CreateAnd(T
->getOperand(0), Mask
);
1059 T
->replaceAllUsesWith(And
);
1060 T
->eraseFromParent();
1064 // Promote immediates.
1065 for (unsigned i
= 0, n
= In
->getNumOperands(); i
!= n
; ++i
) {
1066 if (ConstantInt
*CI
= dyn_cast
<ConstantInt
>(In
->getOperand(i
)))
1067 if (CI
->getType()->getBitWidth() < DestBW
)
1068 In
->setOperand(i
, ConstantInt::get(DestTy
, CI
->getZExtValue()));
1072 bool PolynomialMultiplyRecognize::promoteTypes(BasicBlock
*LoopB
,
1073 BasicBlock
*ExitB
) {
1075 // Skip loops where the exit block has more than one predecessor. The values
1076 // coming from the loop block will be promoted to another type, and so the
1077 // values coming into the exit block from other predecessors would also have
1079 if (!ExitB
|| (ExitB
->getSinglePredecessor() != LoopB
))
1081 IntegerType
*DestTy
= getPmpyType();
1082 // Check if the exit values have types that are no wider than the type
1083 // that we want to promote to.
1084 unsigned DestBW
= DestTy
->getBitWidth();
1085 for (PHINode
&P
: ExitB
->phis()) {
1086 if (P
.getNumIncomingValues() != 1)
1088 assert(P
.getIncomingBlock(0) == LoopB
);
1089 IntegerType
*T
= dyn_cast
<IntegerType
>(P
.getType());
1090 if (!T
|| T
->getBitWidth() > DestBW
)
1094 // Check all instructions in the loop.
1095 for (Instruction
&In
: *LoopB
)
1096 if (!In
.isTerminator() && !isPromotableTo(&In
, DestTy
))
1099 // Perform the promotion.
1100 std::vector
<Instruction
*> LoopIns
;
1101 std::transform(LoopB
->begin(), LoopB
->end(), std::back_inserter(LoopIns
),
1102 [](Instruction
&In
) { return &In
; });
1103 for (Instruction
*In
: LoopIns
)
1104 if (!In
->isTerminator())
1105 promoteTo(In
, DestTy
, LoopB
);
1107 // Fix up the PHI nodes in the exit block.
1108 Instruction
*EndI
= ExitB
->getFirstNonPHI();
1109 BasicBlock::iterator End
= EndI
? EndI
->getIterator() : ExitB
->end();
1110 for (auto I
= ExitB
->begin(); I
!= End
; ++I
) {
1111 PHINode
*P
= dyn_cast
<PHINode
>(I
);
1114 Type
*Ty0
= P
->getIncomingValue(0)->getType();
1115 Type
*PTy
= P
->getType();
1117 assert(Ty0
== DestTy
);
1118 // In order to create the trunc, P must have the promoted type.
1120 Value
*T
= IRBuilder
<>(ExitB
, End
).CreateTrunc(P
, PTy
);
1121 // In order for the RAUW to work, the types of P and T must match.
1123 P
->replaceAllUsesWith(T
);
1124 // Final update of the P's type.
1126 cast
<Instruction
>(T
)->setOperand(0, P
);
1133 bool PolynomialMultiplyRecognize::findCycle(Value
*Out
, Value
*In
,
1135 // Out = ..., In, ...
1139 auto *BB
= cast
<Instruction
>(Out
)->getParent();
1140 bool HadPhi
= false;
1142 for (auto U
: Out
->users()) {
1143 auto *I
= dyn_cast
<Instruction
>(&*U
);
1144 if (I
== nullptr || I
->getParent() != BB
)
1146 // Make sure that there are no multi-iteration cycles, e.g.
1149 // The cycle p1->p2->p1 would span two loop iterations.
1150 // Check that there is only one phi in the cycle.
1151 bool IsPhi
= isa
<PHINode
>(I
);
1152 if (IsPhi
&& HadPhi
)
1158 if (findCycle(I
, In
, Cycle
))
1162 return !Cycle
.empty();
1165 void PolynomialMultiplyRecognize::classifyCycle(Instruction
*DivI
,
1166 ValueSeq
&Cycle
, ValueSeq
&Early
, ValueSeq
&Late
) {
1167 // All the values in the cycle that are between the phi node and the
1168 // divider instruction will be classified as "early", all other values
1172 unsigned I
, N
= Cycle
.size();
1173 for (I
= 0; I
< N
; ++I
) {
1174 Value
*V
= Cycle
[I
];
1177 else if (!isa
<PHINode
>(V
))
1179 // Stop if found either.
1182 // "I" is the index of either DivI or the phi node, whichever was first.
1183 // "E" is "false" or "true" respectively.
1184 ValueSeq
&First
= !IsE
? Early
: Late
;
1185 for (unsigned J
= 0; J
< I
; ++J
)
1186 First
.insert(Cycle
[J
]);
1188 ValueSeq
&Second
= IsE
? Early
: Late
;
1189 Second
.insert(Cycle
[I
]);
1190 for (++I
; I
< N
; ++I
) {
1191 Value
*V
= Cycle
[I
];
1192 if (DivI
== V
|| isa
<PHINode
>(V
))
1198 First
.insert(Cycle
[I
]);
1201 bool PolynomialMultiplyRecognize::classifyInst(Instruction
*UseI
,
1202 ValueSeq
&Early
, ValueSeq
&Late
) {
1203 // Select is an exception, since the condition value does not have to be
1204 // classified in the same way as the true/false values. The true/false
1205 // values do have to be both early or both late.
1206 if (UseI
->getOpcode() == Instruction::Select
) {
1207 Value
*TV
= UseI
->getOperand(1), *FV
= UseI
->getOperand(2);
1208 if (Early
.count(TV
) || Early
.count(FV
)) {
1209 if (Late
.count(TV
) || Late
.count(FV
))
1212 } else if (Late
.count(TV
) || Late
.count(FV
)) {
1213 if (Early
.count(TV
) || Early
.count(FV
))
1220 // Not sure what would be the example of this, but the code below relies
1221 // on having at least one operand.
1222 if (UseI
->getNumOperands() == 0)
1225 bool AE
= true, AL
= true;
1226 for (auto &I
: UseI
->operands()) {
1227 if (Early
.count(&*I
))
1229 else if (Late
.count(&*I
))
1232 // If the operands appear "all early" and "all late" at the same time,
1233 // then it means that none of them are actually classified as either.
1234 // This is harmless.
1237 // Conversely, if they are neither "all early" nor "all late", then
1238 // we have a mixture of early and late operands that is not a known
1243 // Check that we have covered the two special cases.
1253 bool PolynomialMultiplyRecognize::commutesWithShift(Instruction
*I
) {
1254 switch (I
->getOpcode()) {
1255 case Instruction::And
:
1256 case Instruction::Or
:
1257 case Instruction::Xor
:
1258 case Instruction::LShr
:
1259 case Instruction::Shl
:
1260 case Instruction::Select
:
1261 case Instruction::ICmp
:
1262 case Instruction::PHI
:
1270 bool PolynomialMultiplyRecognize::highBitsAreZero(Value
*V
,
1271 unsigned IterCount
) {
1272 auto *T
= dyn_cast
<IntegerType
>(V
->getType());
1276 KnownBits
Known(T
->getBitWidth());
1277 computeKnownBits(V
, Known
, DL
);
1278 return Known
.countMinLeadingZeros() >= IterCount
;
1281 bool PolynomialMultiplyRecognize::keepsHighBitsZero(Value
*V
,
1282 unsigned IterCount
) {
1283 // Assume that all inputs to the value have the high bits zero.
1284 // Check if the value itself preserves the zeros in the high bits.
1285 if (auto *C
= dyn_cast
<ConstantInt
>(V
))
1286 return C
->getValue().countLeadingZeros() >= IterCount
;
1288 if (auto *I
= dyn_cast
<Instruction
>(V
)) {
1289 switch (I
->getOpcode()) {
1290 case Instruction::And
:
1291 case Instruction::Or
:
1292 case Instruction::Xor
:
1293 case Instruction::LShr
:
1294 case Instruction::Select
:
1295 case Instruction::ICmp
:
1296 case Instruction::PHI
:
1297 case Instruction::ZExt
:
1305 bool PolynomialMultiplyRecognize::isOperandShifted(Instruction
*I
, Value
*Op
) {
1306 unsigned Opc
= I
->getOpcode();
1307 if (Opc
== Instruction::Shl
|| Opc
== Instruction::LShr
)
1308 return Op
!= I
->getOperand(1);
1312 bool PolynomialMultiplyRecognize::convertShiftsToLeft(BasicBlock
*LoopB
,
1313 BasicBlock
*ExitB
, unsigned IterCount
) {
1314 Value
*CIV
= getCountIV(LoopB
);
1317 auto *CIVTy
= dyn_cast
<IntegerType
>(CIV
->getType());
1318 if (CIVTy
== nullptr)
1322 ValueSeq Early
, Late
, Cycled
;
1324 // Find all value cycles that contain logical right shifts by 1.
1325 for (Instruction
&I
: *LoopB
) {
1326 using namespace PatternMatch
;
1329 if (!match(&I
, m_LShr(m_Value(V
), m_One())))
1332 if (!findCycle(&I
, V
, C
))
1337 classifyCycle(&I
, C
, Early
, Late
);
1338 Cycled
.insert(C
.begin(), C
.end());
1342 // Find the set of all values affected by the shift cycles, i.e. all
1343 // cycled values, and (recursively) all their users.
1344 ValueSeq
Users(Cycled
.begin(), Cycled
.end());
1345 for (unsigned i
= 0; i
< Users
.size(); ++i
) {
1346 Value
*V
= Users
[i
];
1347 if (!isa
<IntegerType
>(V
->getType()))
1349 auto *R
= cast
<Instruction
>(V
);
1350 // If the instruction does not commute with shifts, the loop cannot
1352 if (!commutesWithShift(R
))
1354 for (User
*U
: R
->users()) {
1355 auto *T
= cast
<Instruction
>(U
);
1356 // Skip users from outside of the loop. They will be handled later.
1357 // Also, skip the right-shifts and phi nodes, since they mix early
1359 if (T
->getParent() != LoopB
|| RShifts
.count(T
) || isa
<PHINode
>(T
))
1363 if (!classifyInst(T
, Early
, Late
))
1371 // Verify that high bits remain zero.
1372 ValueSeq
Internal(Users
.begin(), Users
.end());
1374 for (unsigned i
= 0; i
< Internal
.size(); ++i
) {
1375 auto *R
= dyn_cast
<Instruction
>(Internal
[i
]);
1378 for (Value
*Op
: R
->operands()) {
1379 auto *T
= dyn_cast
<Instruction
>(Op
);
1380 if (T
&& T
->getParent() != LoopB
)
1383 Internal
.insert(Op
);
1386 for (Value
*V
: Inputs
)
1387 if (!highBitsAreZero(V
, IterCount
))
1389 for (Value
*V
: Internal
)
1390 if (!keepsHighBitsZero(V
, IterCount
))
1393 // Finally, the work can be done. Unshift each user.
1394 IRBuilder
<> IRB(LoopB
);
1395 std::map
<Value
*,Value
*> ShiftMap
;
1397 using CastMapType
= std::map
<std::pair
<Value
*, Type
*>, Value
*>;
1399 CastMapType CastMap
;
1401 auto upcast
= [] (CastMapType
&CM
, IRBuilder
<> &IRB
, Value
*V
,
1402 IntegerType
*Ty
) -> Value
* {
1403 auto H
= CM
.find(std::make_pair(V
, Ty
));
1406 Value
*CV
= IRB
.CreateIntCast(V
, Ty
, false);
1407 CM
.insert(std::make_pair(std::make_pair(V
, Ty
), CV
));
1411 for (auto I
= LoopB
->begin(), E
= LoopB
->end(); I
!= E
; ++I
) {
1412 using namespace PatternMatch
;
1414 if (isa
<PHINode
>(I
) || !Users
.count(&*I
))
1419 if (match(&*I
, m_LShr(m_Value(V
), m_One()))) {
1420 replaceAllUsesOfWithIn(&*I
, V
, LoopB
);
1423 // For each non-cycled operand, replace it with the corresponding
1424 // value shifted left.
1425 for (auto &J
: I
->operands()) {
1426 Value
*Op
= J
.get();
1427 if (!isOperandShifted(&*I
, Op
))
1429 if (Users
.count(Op
))
1431 // Skip shifting zeros.
1432 if (isa
<ConstantInt
>(Op
) && cast
<ConstantInt
>(Op
)->isZero())
1434 // Check if we have already generated a shift for this value.
1435 auto F
= ShiftMap
.find(Op
);
1436 Value
*W
= (F
!= ShiftMap
.end()) ? F
->second
: nullptr;
1438 IRB
.SetInsertPoint(&*I
);
1439 // First, the shift amount will be CIV or CIV+1, depending on
1440 // whether the value is early or late. Instead of creating CIV+1,
1441 // do a single shift of the value.
1442 Value
*ShAmt
= CIV
, *ShVal
= Op
;
1443 auto *VTy
= cast
<IntegerType
>(ShVal
->getType());
1444 auto *ATy
= cast
<IntegerType
>(ShAmt
->getType());
1445 if (Late
.count(&*I
))
1446 ShVal
= IRB
.CreateShl(Op
, ConstantInt::get(VTy
, 1));
1447 // Second, the types of the shifted value and the shift amount
1450 if (VTy
->getBitWidth() < ATy
->getBitWidth())
1451 ShVal
= upcast(CastMap
, IRB
, ShVal
, ATy
);
1453 ShAmt
= upcast(CastMap
, IRB
, ShAmt
, VTy
);
1455 // Ready to generate the shift and memoize it.
1456 W
= IRB
.CreateShl(ShVal
, ShAmt
);
1457 ShiftMap
.insert(std::make_pair(Op
, W
));
1459 I
->replaceUsesOfWith(Op
, W
);
1463 // Update the users outside of the loop to account for having left
1464 // shifts. They would normally be shifted right in the loop, so shift
1465 // them right after the loop exit.
1466 // Take advantage of the loop-closed SSA form, which has all the post-
1467 // loop values in phi nodes.
1468 IRB
.SetInsertPoint(ExitB
, ExitB
->getFirstInsertionPt());
1469 for (auto P
= ExitB
->begin(), Q
= ExitB
->end(); P
!= Q
; ++P
) {
1470 if (!isa
<PHINode
>(P
))
1472 auto *PN
= cast
<PHINode
>(P
);
1473 Value
*U
= PN
->getIncomingValueForBlock(LoopB
);
1474 if (!Users
.count(U
))
1476 Value
*S
= IRB
.CreateLShr(PN
, ConstantInt::get(PN
->getType(), IterCount
));
1477 PN
->replaceAllUsesWith(S
);
1478 // The above RAUW will create
1479 // S = lshr S, IterCount
1480 // so we need to fix it back into
1481 // S = lshr PN, IterCount
1482 cast
<User
>(S
)->replaceUsesOfWith(S
, PN
);
1488 void PolynomialMultiplyRecognize::cleanupLoopBody(BasicBlock
*LoopB
) {
1489 for (auto &I
: *LoopB
)
1490 if (Value
*SV
= SimplifyInstruction(&I
, {DL
, &TLI
, &DT
}))
1491 I
.replaceAllUsesWith(SV
);
1493 for (Instruction
&I
: llvm::make_early_inc_range(*LoopB
))
1494 RecursivelyDeleteTriviallyDeadInstructions(&I
, &TLI
);
1497 unsigned PolynomialMultiplyRecognize::getInverseMxN(unsigned QP
) {
1498 // Arrays of coefficients of Q and the inverse, C.
1499 // Q[i] = coefficient at x^i.
1500 std::array
<char,32> Q
, C
;
1502 for (unsigned i
= 0; i
< 32; ++i
) {
1508 // Find C, such that
1509 // (Q[n]*x^n + ... + Q[1]*x + Q[0]) * (C[n]*x^n + ... + C[1]*x + C[0]) = 1
1511 // For it to have a solution, Q[0] must be 1. Since this is Z2[x], the
1512 // operations * and + are & and ^ respectively.
1514 // Find C[i] recursively, by comparing i-th coefficient in the product
1515 // with 0 (or 1 for i=0).
1517 // C[0] = 1, since C[0] = Q[0], and Q[0] = 1.
1519 for (unsigned i
= 1; i
< 32; ++i
) {
1520 // Solve for C[i] in:
1521 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i]Q[0] = 0
1522 // This is equivalent to
1523 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i] = 0
1525 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] = C[i]
1527 for (unsigned j
= 0; j
< i
; ++j
)
1528 T
= T
^ (C
[j
] & Q
[i
-j
]);
1533 for (unsigned i
= 0; i
< 32; ++i
)
1540 Value
*PolynomialMultiplyRecognize::generate(BasicBlock::iterator At
,
1542 IRBuilder
<> B(&*At
);
1543 Module
*M
= At
->getParent()->getParent()->getParent();
1544 Function
*PMF
= Intrinsic::getDeclaration(M
, Intrinsic::hexagon_M4_pmpyw
);
1546 Value
*P
= PV
.P
, *Q
= PV
.Q
, *P0
= P
;
1547 unsigned IC
= PV
.IterCount
;
1549 if (PV
.M
!= nullptr)
1550 P0
= P
= B
.CreateXor(P
, PV
.M
);
1552 // Create a bit mask to clear the high bits beyond IterCount.
1553 auto *BMI
= ConstantInt::get(P
->getType(), APInt::getLowBitsSet(32, IC
));
1555 if (PV
.IterCount
!= 32)
1556 P
= B
.CreateAnd(P
, BMI
);
1559 auto *QI
= dyn_cast
<ConstantInt
>(PV
.Q
);
1560 assert(QI
&& QI
->getBitWidth() <= 32);
1562 // Again, clearing bits beyond IterCount.
1563 unsigned M
= (1 << PV
.IterCount
) - 1;
1564 unsigned Tmp
= (QI
->getZExtValue() | 1) & M
;
1565 unsigned QV
= getInverseMxN(Tmp
) & M
;
1566 auto *QVI
= ConstantInt::get(QI
->getType(), QV
);
1567 P
= B
.CreateCall(PMF
, {P
, QVI
});
1568 P
= B
.CreateTrunc(P
, QI
->getType());
1570 P
= B
.CreateAnd(P
, BMI
);
1573 Value
*R
= B
.CreateCall(PMF
, {P
, Q
});
1575 if (PV
.M
!= nullptr)
1576 R
= B
.CreateXor(R
, B
.CreateIntCast(P0
, R
->getType(), false));
1581 static bool hasZeroSignBit(const Value
*V
) {
1582 if (const auto *CI
= dyn_cast
<const ConstantInt
>(V
))
1583 return (CI
->getType()->getSignBit() & CI
->getSExtValue()) == 0;
1584 const Instruction
*I
= dyn_cast
<const Instruction
>(V
);
1587 switch (I
->getOpcode()) {
1588 case Instruction::LShr
:
1589 if (const auto SI
= dyn_cast
<const ConstantInt
>(I
->getOperand(1)))
1590 return SI
->getZExtValue() > 0;
1592 case Instruction::Or
:
1593 case Instruction::Xor
:
1594 return hasZeroSignBit(I
->getOperand(0)) &&
1595 hasZeroSignBit(I
->getOperand(1));
1596 case Instruction::And
:
1597 return hasZeroSignBit(I
->getOperand(0)) ||
1598 hasZeroSignBit(I
->getOperand(1));
1603 void PolynomialMultiplyRecognize::setupPreSimplifier(Simplifier
&S
) {
1604 S
.addRule("sink-zext",
1605 // Sink zext past bitwise operations.
1606 [](Instruction
*I
, LLVMContext
&Ctx
) -> Value
* {
1607 if (I
->getOpcode() != Instruction::ZExt
)
1609 Instruction
*T
= dyn_cast
<Instruction
>(I
->getOperand(0));
1612 switch (T
->getOpcode()) {
1613 case Instruction::And
:
1614 case Instruction::Or
:
1615 case Instruction::Xor
:
1621 return B
.CreateBinOp(cast
<BinaryOperator
>(T
)->getOpcode(),
1622 B
.CreateZExt(T
->getOperand(0), I
->getType()),
1623 B
.CreateZExt(T
->getOperand(1), I
->getType()));
1625 S
.addRule("xor/and -> and/xor",
1626 // (xor (and x a) (and y a)) -> (and (xor x y) a)
1627 [](Instruction
*I
, LLVMContext
&Ctx
) -> Value
* {
1628 if (I
->getOpcode() != Instruction::Xor
)
1630 Instruction
*And0
= dyn_cast
<Instruction
>(I
->getOperand(0));
1631 Instruction
*And1
= dyn_cast
<Instruction
>(I
->getOperand(1));
1634 if (And0
->getOpcode() != Instruction::And
||
1635 And1
->getOpcode() != Instruction::And
)
1637 if (And0
->getOperand(1) != And1
->getOperand(1))
1640 return B
.CreateAnd(B
.CreateXor(And0
->getOperand(0), And1
->getOperand(0)),
1641 And0
->getOperand(1));
1643 S
.addRule("sink binop into select",
1644 // (Op (select c x y) z) -> (select c (Op x z) (Op y z))
1645 // (Op x (select c y z)) -> (select c (Op x y) (Op x z))
1646 [](Instruction
*I
, LLVMContext
&Ctx
) -> Value
* {
1647 BinaryOperator
*BO
= dyn_cast
<BinaryOperator
>(I
);
1650 Instruction::BinaryOps Op
= BO
->getOpcode();
1651 if (SelectInst
*Sel
= dyn_cast
<SelectInst
>(BO
->getOperand(0))) {
1653 Value
*X
= Sel
->getTrueValue(), *Y
= Sel
->getFalseValue();
1654 Value
*Z
= BO
->getOperand(1);
1655 return B
.CreateSelect(Sel
->getCondition(),
1656 B
.CreateBinOp(Op
, X
, Z
),
1657 B
.CreateBinOp(Op
, Y
, Z
));
1659 if (SelectInst
*Sel
= dyn_cast
<SelectInst
>(BO
->getOperand(1))) {
1661 Value
*X
= BO
->getOperand(0);
1662 Value
*Y
= Sel
->getTrueValue(), *Z
= Sel
->getFalseValue();
1663 return B
.CreateSelect(Sel
->getCondition(),
1664 B
.CreateBinOp(Op
, X
, Y
),
1665 B
.CreateBinOp(Op
, X
, Z
));
1669 S
.addRule("fold select-select",
1670 // (select c (select c x y) z) -> (select c x z)
1671 // (select c x (select c y z)) -> (select c x z)
1672 [](Instruction
*I
, LLVMContext
&Ctx
) -> Value
* {
1673 SelectInst
*Sel
= dyn_cast
<SelectInst
>(I
);
1677 Value
*C
= Sel
->getCondition();
1678 if (SelectInst
*Sel0
= dyn_cast
<SelectInst
>(Sel
->getTrueValue())) {
1679 if (Sel0
->getCondition() == C
)
1680 return B
.CreateSelect(C
, Sel0
->getTrueValue(), Sel
->getFalseValue());
1682 if (SelectInst
*Sel1
= dyn_cast
<SelectInst
>(Sel
->getFalseValue())) {
1683 if (Sel1
->getCondition() == C
)
1684 return B
.CreateSelect(C
, Sel
->getTrueValue(), Sel1
->getFalseValue());
1688 S
.addRule("or-signbit -> xor-signbit",
1689 // (or (lshr x 1) 0x800.0) -> (xor (lshr x 1) 0x800.0)
1690 [](Instruction
*I
, LLVMContext
&Ctx
) -> Value
* {
1691 if (I
->getOpcode() != Instruction::Or
)
1693 ConstantInt
*Msb
= dyn_cast
<ConstantInt
>(I
->getOperand(1));
1694 if (!Msb
|| Msb
->getZExtValue() != Msb
->getType()->getSignBit())
1696 if (!hasZeroSignBit(I
->getOperand(0)))
1698 return IRBuilder
<>(Ctx
).CreateXor(I
->getOperand(0), Msb
);
1700 S
.addRule("sink lshr into binop",
1701 // (lshr (BitOp x y) c) -> (BitOp (lshr x c) (lshr y c))
1702 [](Instruction
*I
, LLVMContext
&Ctx
) -> Value
* {
1703 if (I
->getOpcode() != Instruction::LShr
)
1705 BinaryOperator
*BitOp
= dyn_cast
<BinaryOperator
>(I
->getOperand(0));
1708 switch (BitOp
->getOpcode()) {
1709 case Instruction::And
:
1710 case Instruction::Or
:
1711 case Instruction::Xor
:
1717 Value
*S
= I
->getOperand(1);
1718 return B
.CreateBinOp(BitOp
->getOpcode(),
1719 B
.CreateLShr(BitOp
->getOperand(0), S
),
1720 B
.CreateLShr(BitOp
->getOperand(1), S
));
1722 S
.addRule("expose bitop-const",
1723 // (BitOp1 (BitOp2 x a) b) -> (BitOp2 x (BitOp1 a b))
1724 [](Instruction
*I
, LLVMContext
&Ctx
) -> Value
* {
1725 auto IsBitOp
= [](unsigned Op
) -> bool {
1727 case Instruction::And
:
1728 case Instruction::Or
:
1729 case Instruction::Xor
:
1734 BinaryOperator
*BitOp1
= dyn_cast
<BinaryOperator
>(I
);
1735 if (!BitOp1
|| !IsBitOp(BitOp1
->getOpcode()))
1737 BinaryOperator
*BitOp2
= dyn_cast
<BinaryOperator
>(BitOp1
->getOperand(0));
1738 if (!BitOp2
|| !IsBitOp(BitOp2
->getOpcode()))
1740 ConstantInt
*CA
= dyn_cast
<ConstantInt
>(BitOp2
->getOperand(1));
1741 ConstantInt
*CB
= dyn_cast
<ConstantInt
>(BitOp1
->getOperand(1));
1745 Value
*X
= BitOp2
->getOperand(0);
1746 return B
.CreateBinOp(BitOp2
->getOpcode(), X
,
1747 B
.CreateBinOp(BitOp1
->getOpcode(), CA
, CB
));
1751 void PolynomialMultiplyRecognize::setupPostSimplifier(Simplifier
&S
) {
1752 S
.addRule("(and (xor (and x a) y) b) -> (and (xor x y) b), if b == b&a",
1753 [](Instruction
*I
, LLVMContext
&Ctx
) -> Value
* {
1754 if (I
->getOpcode() != Instruction::And
)
1756 Instruction
*Xor
= dyn_cast
<Instruction
>(I
->getOperand(0));
1757 ConstantInt
*C0
= dyn_cast
<ConstantInt
>(I
->getOperand(1));
1760 if (Xor
->getOpcode() != Instruction::Xor
)
1762 Instruction
*And0
= dyn_cast
<Instruction
>(Xor
->getOperand(0));
1763 Instruction
*And1
= dyn_cast
<Instruction
>(Xor
->getOperand(1));
1764 // Pick the first non-null and.
1765 if (!And0
|| And0
->getOpcode() != Instruction::And
)
1766 std::swap(And0
, And1
);
1767 ConstantInt
*C1
= dyn_cast
<ConstantInt
>(And0
->getOperand(1));
1770 uint32_t V0
= C0
->getZExtValue();
1771 uint32_t V1
= C1
->getZExtValue();
1772 if (V0
!= (V0
& V1
))
1775 return B
.CreateAnd(B
.CreateXor(And0
->getOperand(0), And1
), C0
);
1779 bool PolynomialMultiplyRecognize::recognize() {
1780 LLVM_DEBUG(dbgs() << "Starting PolynomialMultiplyRecognize on loop\n"
1781 << *CurLoop
<< '\n');
1783 // - The loop must consist of a single block.
1784 // - The iteration count must be known at compile-time.
1785 // - The loop must have an induction variable starting from 0, and
1786 // incremented in each iteration of the loop.
1787 BasicBlock
*LoopB
= CurLoop
->getHeader();
1788 LLVM_DEBUG(dbgs() << "Loop header:\n" << *LoopB
);
1790 if (LoopB
!= CurLoop
->getLoopLatch())
1792 BasicBlock
*ExitB
= CurLoop
->getExitBlock();
1793 if (ExitB
== nullptr)
1795 BasicBlock
*EntryB
= CurLoop
->getLoopPreheader();
1796 if (EntryB
== nullptr)
1799 unsigned IterCount
= 0;
1800 const SCEV
*CT
= SE
.getBackedgeTakenCount(CurLoop
);
1801 if (isa
<SCEVCouldNotCompute
>(CT
))
1803 if (auto *CV
= dyn_cast
<SCEVConstant
>(CT
))
1804 IterCount
= CV
->getValue()->getZExtValue() + 1;
1806 Value
*CIV
= getCountIV(LoopB
);
1809 PV
.IterCount
= IterCount
;
1810 LLVM_DEBUG(dbgs() << "Loop IV: " << *CIV
<< "\nIterCount: " << IterCount
1813 setupPreSimplifier(PreSimp
);
1815 // Perform a preliminary scan of select instructions to see if any of them
1816 // looks like a generator of the polynomial multiply steps. Assume that a
1817 // loop can only contain a single transformable operation, so stop the
1818 // traversal after the first reasonable candidate was found.
1819 // XXX: Currently this approach can modify the loop before being 100% sure
1820 // that the transformation can be carried out.
1821 bool FoundPreScan
= false;
1822 auto FeedsPHI
= [LoopB
](const Value
*V
) -> bool {
1823 for (const Value
*U
: V
->users()) {
1824 if (const auto *P
= dyn_cast
<const PHINode
>(U
))
1825 if (P
->getParent() == LoopB
)
1830 for (Instruction
&In
: *LoopB
) {
1831 SelectInst
*SI
= dyn_cast
<SelectInst
>(&In
);
1832 if (!SI
|| !FeedsPHI(SI
))
1835 Simplifier::Context
C(SI
);
1836 Value
*T
= PreSimp
.simplify(C
);
1837 SelectInst
*SelI
= (T
&& isa
<SelectInst
>(T
)) ? cast
<SelectInst
>(T
) : SI
;
1838 LLVM_DEBUG(dbgs() << "scanSelect(pre-scan): " << PE(C
, SelI
) << '\n');
1839 if (scanSelect(SelI
, LoopB
, EntryB
, CIV
, PV
, true)) {
1840 FoundPreScan
= true;
1842 Value
*NewSel
= C
.materialize(LoopB
, SI
->getIterator());
1843 SI
->replaceAllUsesWith(NewSel
);
1844 RecursivelyDeleteTriviallyDeadInstructions(SI
, &TLI
);
1850 if (!FoundPreScan
) {
1851 LLVM_DEBUG(dbgs() << "Have not found candidates for pmpy\n");
1856 // The right shift version actually only returns the higher bits of
1857 // the result (each iteration discards the LSB). If we want to convert it
1858 // to a left-shifting loop, the working data type must be at least as
1859 // wide as the target's pmpy instruction.
1860 if (!promoteTypes(LoopB
, ExitB
))
1862 // Run post-promotion simplifications.
1863 Simplifier PostSimp
;
1864 setupPostSimplifier(PostSimp
);
1865 for (Instruction
&In
: *LoopB
) {
1866 SelectInst
*SI
= dyn_cast
<SelectInst
>(&In
);
1867 if (!SI
|| !FeedsPHI(SI
))
1869 Simplifier::Context
C(SI
);
1870 Value
*T
= PostSimp
.simplify(C
);
1871 SelectInst
*SelI
= dyn_cast_or_null
<SelectInst
>(T
);
1873 Value
*NewSel
= C
.materialize(LoopB
, SI
->getIterator());
1874 SI
->replaceAllUsesWith(NewSel
);
1875 RecursivelyDeleteTriviallyDeadInstructions(SI
, &TLI
);
1880 if (!convertShiftsToLeft(LoopB
, ExitB
, IterCount
))
1882 cleanupLoopBody(LoopB
);
1885 // Scan the loop again, find the generating select instruction.
1886 bool FoundScan
= false;
1887 for (Instruction
&In
: *LoopB
) {
1888 SelectInst
*SelI
= dyn_cast
<SelectInst
>(&In
);
1891 LLVM_DEBUG(dbgs() << "scanSelect: " << *SelI
<< '\n');
1892 FoundScan
= scanSelect(SelI
, LoopB
, EntryB
, CIV
, PV
, false);
1899 StringRef PP
= (PV
.M
? "(P+M)" : "P");
1901 dbgs() << "Found pmpy idiom: R = " << PP
<< ".Q\n";
1903 dbgs() << "Found inverse pmpy idiom: R = (" << PP
<< "/Q).Q) + "
1905 dbgs() << " Res:" << *PV
.Res
<< "\n P:" << *PV
.P
<< "\n";
1907 dbgs() << " M:" << *PV
.M
<< "\n";
1908 dbgs() << " Q:" << *PV
.Q
<< "\n";
1909 dbgs() << " Iteration count:" << PV
.IterCount
<< "\n";
1912 BasicBlock::iterator
At(EntryB
->getTerminator());
1913 Value
*PM
= generate(At
, PV
);
1917 if (PM
->getType() != PV
.Res
->getType())
1918 PM
= IRBuilder
<>(&*At
).CreateIntCast(PM
, PV
.Res
->getType(), false);
1920 PV
.Res
->replaceAllUsesWith(PM
);
1921 PV
.Res
->eraseFromParent();
1925 int HexagonLoopIdiomRecognize::getSCEVStride(const SCEVAddRecExpr
*S
) {
1926 if (const SCEVConstant
*SC
= dyn_cast
<SCEVConstant
>(S
->getOperand(1)))
1927 return SC
->getAPInt().getSExtValue();
1931 bool HexagonLoopIdiomRecognize::isLegalStore(Loop
*CurLoop
, StoreInst
*SI
) {
1932 // Allow volatile stores if HexagonVolatileMemcpy is enabled.
1933 if (!(SI
->isVolatile() && HexagonVolatileMemcpy
) && !SI
->isSimple())
1936 Value
*StoredVal
= SI
->getValueOperand();
1937 Value
*StorePtr
= SI
->getPointerOperand();
1939 // Reject stores that are so large that they overflow an unsigned.
1940 uint64_t SizeInBits
= DL
->getTypeSizeInBits(StoredVal
->getType());
1941 if ((SizeInBits
& 7) || (SizeInBits
>> 32) != 0)
1944 // See if the pointer expression is an AddRec like {base,+,1} on the current
1945 // loop, which indicates a strided store. If we have something else, it's a
1946 // random store we can't handle.
1947 auto *StoreEv
= dyn_cast
<SCEVAddRecExpr
>(SE
->getSCEV(StorePtr
));
1948 if (!StoreEv
|| StoreEv
->getLoop() != CurLoop
|| !StoreEv
->isAffine())
1951 // Check to see if the stride matches the size of the store. If so, then we
1952 // know that every byte is touched in the loop.
1953 int Stride
= getSCEVStride(StoreEv
);
1956 unsigned StoreSize
= DL
->getTypeStoreSize(SI
->getValueOperand()->getType());
1957 if (StoreSize
!= unsigned(std::abs(Stride
)))
1960 // The store must be feeding a non-volatile load.
1961 LoadInst
*LI
= dyn_cast
<LoadInst
>(SI
->getValueOperand());
1962 if (!LI
|| !LI
->isSimple())
1965 // See if the pointer expression is an AddRec like {base,+,1} on the current
1966 // loop, which indicates a strided load. If we have something else, it's a
1967 // random load we can't handle.
1968 Value
*LoadPtr
= LI
->getPointerOperand();
1969 auto *LoadEv
= dyn_cast
<SCEVAddRecExpr
>(SE
->getSCEV(LoadPtr
));
1970 if (!LoadEv
|| LoadEv
->getLoop() != CurLoop
|| !LoadEv
->isAffine())
1973 // The store and load must share the same stride.
1974 if (StoreEv
->getOperand(1) != LoadEv
->getOperand(1))
1977 // Success. This store can be converted into a memcpy.
1981 /// mayLoopAccessLocation - Return true if the specified loop might access the
1982 /// specified pointer location, which is a loop-strided access. The 'Access'
1983 /// argument specifies what the verboten forms of access are (read or write).
1985 mayLoopAccessLocation(Value
*Ptr
, ModRefInfo Access
, Loop
*L
,
1986 const SCEV
*BECount
, unsigned StoreSize
,
1988 SmallPtrSetImpl
<Instruction
*> &Ignored
) {
1989 // Get the location that may be stored across the loop. Since the access
1990 // is strided positively through memory, we say that the modified location
1991 // starts at the pointer and has infinite size.
1992 LocationSize AccessSize
= LocationSize::afterPointer();
1994 // If the loop iterates a fixed number of times, we can refine the access
1995 // size to be exactly the size of the memset, which is (BECount+1)*StoreSize
1996 if (const SCEVConstant
*BECst
= dyn_cast
<SCEVConstant
>(BECount
))
1997 AccessSize
= LocationSize::precise((BECst
->getValue()->getZExtValue() + 1) *
2000 // TODO: For this to be really effective, we have to dive into the pointer
2001 // operand in the store. Store to &A[i] of 100 will always return may alias
2002 // with store of &A[100], we need to StoreLoc to be "A" with size of 100,
2003 // which will then no-alias a store to &A[100].
2004 MemoryLocation
StoreLoc(Ptr
, AccessSize
);
2006 for (auto *B
: L
->blocks())
2008 if (Ignored
.count(&I
) == 0 &&
2010 intersectModRef(AA
.getModRefInfo(&I
, StoreLoc
), Access
)))
2016 void HexagonLoopIdiomRecognize::collectStores(Loop
*CurLoop
, BasicBlock
*BB
,
2017 SmallVectorImpl
<StoreInst
*> &Stores
) {
2019 for (Instruction
&I
: *BB
)
2020 if (StoreInst
*SI
= dyn_cast
<StoreInst
>(&I
))
2021 if (isLegalStore(CurLoop
, SI
))
2022 Stores
.push_back(SI
);
2025 bool HexagonLoopIdiomRecognize::processCopyingStore(Loop
*CurLoop
,
2026 StoreInst
*SI
, const SCEV
*BECount
) {
2027 assert((SI
->isSimple() || (SI
->isVolatile() && HexagonVolatileMemcpy
)) &&
2028 "Expected only non-volatile stores, or Hexagon-specific memcpy"
2029 "to volatile destination.");
2031 Value
*StorePtr
= SI
->getPointerOperand();
2032 auto *StoreEv
= cast
<SCEVAddRecExpr
>(SE
->getSCEV(StorePtr
));
2033 unsigned Stride
= getSCEVStride(StoreEv
);
2034 unsigned StoreSize
= DL
->getTypeStoreSize(SI
->getValueOperand()->getType());
2035 if (Stride
!= StoreSize
)
2038 // See if the pointer expression is an AddRec like {base,+,1} on the current
2039 // loop, which indicates a strided load. If we have something else, it's a
2040 // random load we can't handle.
2041 auto *LI
= cast
<LoadInst
>(SI
->getValueOperand());
2042 auto *LoadEv
= cast
<SCEVAddRecExpr
>(SE
->getSCEV(LI
->getPointerOperand()));
2044 // The trip count of the loop and the base pointer of the addrec SCEV is
2045 // guaranteed to be loop invariant, which means that it should dominate the
2046 // header. This allows us to insert code for it in the preheader.
2047 BasicBlock
*Preheader
= CurLoop
->getLoopPreheader();
2048 Instruction
*ExpPt
= Preheader
->getTerminator();
2049 IRBuilder
<> Builder(ExpPt
);
2050 SCEVExpander
Expander(*SE
, *DL
, "hexagon-loop-idiom");
2052 Type
*IntPtrTy
= Builder
.getIntPtrTy(*DL
, SI
->getPointerAddressSpace());
2054 // Okay, we have a strided store "p[i]" of a loaded value. We can turn
2055 // this into a memcpy/memmove in the loop preheader now if we want. However,
2056 // this would be unsafe to do if there is anything else in the loop that may
2057 // read or write the memory region we're storing to. For memcpy, this
2058 // includes the load that feeds the stores. Check for an alias by generating
2059 // the base address and checking everything.
2060 Value
*StoreBasePtr
= Expander
.expandCodeFor(StoreEv
->getStart(),
2061 Builder
.getInt8PtrTy(SI
->getPointerAddressSpace()), ExpPt
);
2062 Value
*LoadBasePtr
= nullptr;
2064 bool Overlap
= false;
2065 bool DestVolatile
= SI
->isVolatile();
2066 Type
*BECountTy
= BECount
->getType();
2069 // The trip count must fit in i32, since it is the type of the "num_words"
2070 // argument to hexagon_memcpy_forward_vp4cp4n2.
2071 if (StoreSize
!= 4 || DL
->getTypeSizeInBits(BECountTy
) > 32) {
2073 // If we generated new code for the base pointer, clean up.
2075 if (StoreBasePtr
&& (LoadBasePtr
!= StoreBasePtr
)) {
2076 RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr
, TLI
);
2077 StoreBasePtr
= nullptr;
2080 RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr
, TLI
);
2081 LoadBasePtr
= nullptr;
2087 SmallPtrSet
<Instruction
*, 2> Ignore1
;
2089 if (mayLoopAccessLocation(StoreBasePtr
, ModRefInfo::ModRef
, CurLoop
, BECount
,
2090 StoreSize
, *AA
, Ignore1
)) {
2091 // Check if the load is the offending instruction.
2093 if (mayLoopAccessLocation(StoreBasePtr
, ModRefInfo::ModRef
, CurLoop
,
2094 BECount
, StoreSize
, *AA
, Ignore1
)) {
2095 // Still bad. Nothing we can do.
2096 goto CleanupAndExit
;
2098 // It worked with the load ignored.
2103 if (DisableMemcpyIdiom
|| !HasMemcpy
)
2104 goto CleanupAndExit
;
2106 // Don't generate memmove if this function will be inlined. This is
2107 // because the caller will undergo this transformation after inlining.
2108 Function
*Func
= CurLoop
->getHeader()->getParent();
2109 if (Func
->hasFnAttribute(Attribute::AlwaysInline
))
2110 goto CleanupAndExit
;
2112 // In case of a memmove, the call to memmove will be executed instead
2113 // of the loop, so we need to make sure that there is nothing else in
2114 // the loop than the load, store and instructions that these two depend
2116 SmallVector
<Instruction
*,2> Insts
;
2117 Insts
.push_back(SI
);
2118 Insts
.push_back(LI
);
2119 if (!coverLoop(CurLoop
, Insts
))
2120 goto CleanupAndExit
;
2122 if (DisableMemmoveIdiom
|| !HasMemmove
)
2123 goto CleanupAndExit
;
2124 bool IsNested
= CurLoop
->getParentLoop() != nullptr;
2125 if (IsNested
&& OnlyNonNestedMemmove
)
2126 goto CleanupAndExit
;
2129 // For a memcpy, we have to make sure that the input array is not being
2130 // mutated by the loop.
2131 LoadBasePtr
= Expander
.expandCodeFor(LoadEv
->getStart(),
2132 Builder
.getInt8PtrTy(LI
->getPointerAddressSpace()), ExpPt
);
2134 SmallPtrSet
<Instruction
*, 2> Ignore2
;
2136 if (mayLoopAccessLocation(LoadBasePtr
, ModRefInfo::Mod
, CurLoop
, BECount
,
2137 StoreSize
, *AA
, Ignore2
))
2138 goto CleanupAndExit
;
2140 // Check the stride.
2141 bool StridePos
= getSCEVStride(LoadEv
) >= 0;
2143 // Currently, the volatile memcpy only emulates traversing memory forward.
2144 if (!StridePos
&& DestVolatile
)
2145 goto CleanupAndExit
;
2147 bool RuntimeCheck
= (Overlap
|| DestVolatile
);
2151 // The runtime check needs a single exit block.
2152 SmallVector
<BasicBlock
*, 8> ExitBlocks
;
2153 CurLoop
->getUniqueExitBlocks(ExitBlocks
);
2154 if (ExitBlocks
.size() != 1)
2155 goto CleanupAndExit
;
2156 ExitB
= ExitBlocks
[0];
2159 // The # stored bytes is (BECount+1)*Size. Expand the trip count out to
2160 // pointer size if it isn't already.
2161 LLVMContext
&Ctx
= SI
->getContext();
2162 BECount
= SE
->getTruncateOrZeroExtend(BECount
, IntPtrTy
);
2163 DebugLoc DLoc
= SI
->getDebugLoc();
2165 const SCEV
*NumBytesS
=
2166 SE
->getAddExpr(BECount
, SE
->getOne(IntPtrTy
), SCEV::FlagNUW
);
2168 NumBytesS
= SE
->getMulExpr(NumBytesS
, SE
->getConstant(IntPtrTy
, StoreSize
),
2170 Value
*NumBytes
= Expander
.expandCodeFor(NumBytesS
, IntPtrTy
, ExpPt
);
2171 if (Instruction
*In
= dyn_cast
<Instruction
>(NumBytes
))
2172 if (Value
*Simp
= SimplifyInstruction(In
, {*DL
, TLI
, DT
}))
2178 unsigned Threshold
= RuntimeMemSizeThreshold
;
2179 if (ConstantInt
*CI
= dyn_cast
<ConstantInt
>(NumBytes
)) {
2180 uint64_t C
= CI
->getZExtValue();
2181 if (Threshold
!= 0 && C
< Threshold
)
2182 goto CleanupAndExit
;
2183 if (C
< CompileTimeMemSizeThreshold
)
2184 goto CleanupAndExit
;
2187 BasicBlock
*Header
= CurLoop
->getHeader();
2188 Function
*Func
= Header
->getParent();
2189 Loop
*ParentL
= LF
->getLoopFor(Preheader
);
2190 StringRef HeaderName
= Header
->getName();
2192 // Create a new (empty) preheader, and update the PHI nodes in the
2193 // header to use the new preheader.
2194 BasicBlock
*NewPreheader
= BasicBlock::Create(Ctx
, HeaderName
+".rtli.ph",
2197 ParentL
->addBasicBlockToLoop(NewPreheader
, *LF
);
2198 IRBuilder
<>(NewPreheader
).CreateBr(Header
);
2199 for (auto &In
: *Header
) {
2200 PHINode
*PN
= dyn_cast
<PHINode
>(&In
);
2203 int bx
= PN
->getBasicBlockIndex(Preheader
);
2205 PN
->setIncomingBlock(bx
, NewPreheader
);
2207 DT
->addNewBlock(NewPreheader
, Preheader
);
2208 DT
->changeImmediateDominator(Header
, NewPreheader
);
2210 // Check for safe conditions to execute memmove.
2211 // If stride is positive, copying things from higher to lower addresses
2212 // is equivalent to memmove. For negative stride, it's the other way
2213 // around. Copying forward in memory with positive stride may not be
2214 // same as memmove since we may be copying values that we just stored
2215 // in some previous iteration.
2216 Value
*LA
= Builder
.CreatePtrToInt(LoadBasePtr
, IntPtrTy
);
2217 Value
*SA
= Builder
.CreatePtrToInt(StoreBasePtr
, IntPtrTy
);
2218 Value
*LowA
= StridePos
? SA
: LA
;
2219 Value
*HighA
= StridePos
? LA
: SA
;
2220 Value
*CmpA
= Builder
.CreateICmpULT(LowA
, HighA
);
2223 // Check for distance between pointers. Since the case LowA < HighA
2224 // is checked for above, assume LowA >= HighA.
2225 Value
*Dist
= Builder
.CreateSub(LowA
, HighA
);
2226 Value
*CmpD
= Builder
.CreateICmpSLE(NumBytes
, Dist
);
2227 Value
*CmpEither
= Builder
.CreateOr(Cond
, CmpD
);
2230 if (Threshold
!= 0) {
2231 Type
*Ty
= NumBytes
->getType();
2232 Value
*Thr
= ConstantInt::get(Ty
, Threshold
);
2233 Value
*CmpB
= Builder
.CreateICmpULT(Thr
, NumBytes
);
2234 Value
*CmpBoth
= Builder
.CreateAnd(Cond
, CmpB
);
2237 BasicBlock
*MemmoveB
= BasicBlock::Create(Ctx
, Header
->getName()+".rtli",
2238 Func
, NewPreheader
);
2240 ParentL
->addBasicBlockToLoop(MemmoveB
, *LF
);
2241 Instruction
*OldT
= Preheader
->getTerminator();
2242 Builder
.CreateCondBr(Cond
, MemmoveB
, NewPreheader
);
2243 OldT
->eraseFromParent();
2244 Preheader
->setName(Preheader
->getName()+".old");
2245 DT
->addNewBlock(MemmoveB
, Preheader
);
2246 // Find the new immediate dominator of the exit block.
2247 BasicBlock
*ExitD
= Preheader
;
2248 for (BasicBlock
*PB
: predecessors(ExitB
)) {
2249 ExitD
= DT
->findNearestCommonDominator(ExitD
, PB
);
2253 // If the prior immediate dominator of ExitB was dominated by the
2254 // old preheader, then the old preheader becomes the new immediate
2255 // dominator. Otherwise don't change anything (because the newly
2256 // added blocks are dominated by the old preheader).
2257 if (ExitD
&& DT
->dominates(Preheader
, ExitD
)) {
2258 DomTreeNode
*BN
= DT
->getNode(ExitB
);
2259 DomTreeNode
*DN
= DT
->getNode(ExitD
);
2263 // Add a call to memmove to the conditional block.
2264 IRBuilder
<> CondBuilder(MemmoveB
);
2265 CondBuilder
.CreateBr(ExitB
);
2266 CondBuilder
.SetInsertPoint(MemmoveB
->getTerminator());
2269 Type
*Int32Ty
= Type::getInt32Ty(Ctx
);
2270 Type
*Int32PtrTy
= Type::getInt32PtrTy(Ctx
);
2271 Type
*VoidTy
= Type::getVoidTy(Ctx
);
2272 Module
*M
= Func
->getParent();
2273 FunctionCallee Fn
= M
->getOrInsertFunction(
2274 HexagonVolatileMemcpyName
, VoidTy
, Int32PtrTy
, Int32PtrTy
, Int32Ty
);
2276 const SCEV
*OneS
= SE
->getConstant(Int32Ty
, 1);
2277 const SCEV
*BECount32
= SE
->getTruncateOrZeroExtend(BECount
, Int32Ty
);
2278 const SCEV
*NumWordsS
= SE
->getAddExpr(BECount32
, OneS
, SCEV::FlagNUW
);
2279 Value
*NumWords
= Expander
.expandCodeFor(NumWordsS
, Int32Ty
,
2280 MemmoveB
->getTerminator());
2281 if (Instruction
*In
= dyn_cast
<Instruction
>(NumWords
))
2282 if (Value
*Simp
= SimplifyInstruction(In
, {*DL
, TLI
, DT
}))
2285 Value
*Op0
= (StoreBasePtr
->getType() == Int32PtrTy
)
2287 : CondBuilder
.CreateBitCast(StoreBasePtr
, Int32PtrTy
);
2288 Value
*Op1
= (LoadBasePtr
->getType() == Int32PtrTy
)
2290 : CondBuilder
.CreateBitCast(LoadBasePtr
, Int32PtrTy
);
2291 NewCall
= CondBuilder
.CreateCall(Fn
, {Op0
, Op1
, NumWords
});
2293 NewCall
= CondBuilder
.CreateMemMove(
2294 StoreBasePtr
, SI
->getAlign(), LoadBasePtr
, LI
->getAlign(), NumBytes
);
2297 NewCall
= Builder
.CreateMemCpy(StoreBasePtr
, SI
->getAlign(), LoadBasePtr
,
2298 LI
->getAlign(), NumBytes
);
2299 // Okay, the memcpy has been formed. Zap the original store and
2300 // anything that feeds into it.
2301 RecursivelyDeleteTriviallyDeadInstructions(SI
, TLI
);
2304 NewCall
->setDebugLoc(DLoc
);
2306 LLVM_DEBUG(dbgs() << " Formed " << (Overlap
? "memmove: " : "memcpy: ")
2308 << " from load ptr=" << *LoadEv
<< " at: " << *LI
<< "\n"
2309 << " from store ptr=" << *StoreEv
<< " at: " << *SI
2315 // Check if the instructions in Insts, together with their dependencies
2316 // cover the loop in the sense that the loop could be safely eliminated once
2317 // the instructions in Insts are removed.
2318 bool HexagonLoopIdiomRecognize::coverLoop(Loop
*L
,
2319 SmallVectorImpl
<Instruction
*> &Insts
) const {
2320 SmallSet
<BasicBlock
*,8> LoopBlocks
;
2321 for (auto *B
: L
->blocks())
2322 LoopBlocks
.insert(B
);
2324 SetVector
<Instruction
*> Worklist(Insts
.begin(), Insts
.end());
2326 // Collect all instructions from the loop that the instructions in Insts
2327 // depend on (plus their dependencies, etc.). These instructions will
2328 // constitute the expression trees that feed those in Insts, but the trees
2329 // will be limited only to instructions contained in the loop.
2330 for (unsigned i
= 0; i
< Worklist
.size(); ++i
) {
2331 Instruction
*In
= Worklist
[i
];
2332 for (auto I
= In
->op_begin(), E
= In
->op_end(); I
!= E
; ++I
) {
2333 Instruction
*OpI
= dyn_cast
<Instruction
>(I
);
2336 BasicBlock
*PB
= OpI
->getParent();
2337 if (!LoopBlocks
.count(PB
))
2339 Worklist
.insert(OpI
);
2343 // Scan all instructions in the loop, if any of them have a user outside
2344 // of the loop, or outside of the expressions collected above, then either
2345 // the loop has a side-effect visible outside of it, or there are
2346 // instructions in it that are not involved in the original set Insts.
2347 for (auto *B
: L
->blocks()) {
2348 for (auto &In
: *B
) {
2349 if (isa
<BranchInst
>(In
) || isa
<DbgInfoIntrinsic
>(In
))
2351 if (!Worklist
.count(&In
) && In
.mayHaveSideEffects())
2353 for (auto K
: In
.users()) {
2354 Instruction
*UseI
= dyn_cast
<Instruction
>(K
);
2357 BasicBlock
*UseB
= UseI
->getParent();
2358 if (LF
->getLoopFor(UseB
) != L
)
2367 /// runOnLoopBlock - Process the specified block, which lives in a counted loop
2368 /// with the specified backedge count. This block is known to be in the current
2369 /// loop and not in any subloops.
2370 bool HexagonLoopIdiomRecognize::runOnLoopBlock(Loop
*CurLoop
, BasicBlock
*BB
,
2371 const SCEV
*BECount
, SmallVectorImpl
<BasicBlock
*> &ExitBlocks
) {
2372 // We can only promote stores in this block if they are unconditionally
2373 // executed in the loop. For a block to be unconditionally executed, it has
2374 // to dominate all the exit blocks of the loop. Verify this now.
2375 auto DominatedByBB
= [this,BB
] (BasicBlock
*EB
) -> bool {
2376 return DT
->dominates(BB
, EB
);
2378 if (!all_of(ExitBlocks
, DominatedByBB
))
2381 bool MadeChange
= false;
2382 // Look for store instructions, which may be optimized to memset/memcpy.
2383 SmallVector
<StoreInst
*,8> Stores
;
2384 collectStores(CurLoop
, BB
, Stores
);
2386 // Optimize the store into a memcpy, if it feeds an similarly strided load.
2387 for (auto &SI
: Stores
)
2388 MadeChange
|= processCopyingStore(CurLoop
, SI
, BECount
);
2393 bool HexagonLoopIdiomRecognize::runOnCountableLoop(Loop
*L
) {
2394 PolynomialMultiplyRecognize
PMR(L
, *DL
, *DT
, *TLI
, *SE
);
2395 if (PMR
.recognize())
2398 if (!HasMemcpy
&& !HasMemmove
)
2401 const SCEV
*BECount
= SE
->getBackedgeTakenCount(L
);
2402 assert(!isa
<SCEVCouldNotCompute
>(BECount
) &&
2403 "runOnCountableLoop() called on a loop without a predictable"
2404 "backedge-taken count");
2406 SmallVector
<BasicBlock
*, 8> ExitBlocks
;
2407 L
->getUniqueExitBlocks(ExitBlocks
);
2409 bool Changed
= false;
2411 // Scan all the blocks in the loop that are not in subloops.
2412 for (auto *BB
: L
->getBlocks()) {
2413 // Ignore blocks in subloops.
2414 if (LF
->getLoopFor(BB
) != L
)
2416 Changed
|= runOnLoopBlock(L
, BB
, BECount
, ExitBlocks
);
2422 bool HexagonLoopIdiomRecognize::run(Loop
*L
) {
2423 const Module
&M
= *L
->getHeader()->getParent()->getParent();
2424 if (Triple(M
.getTargetTriple()).getArch() != Triple::hexagon
)
2427 // If the loop could not be converted to canonical form, it must have an
2428 // indirectbr in it, just give up.
2429 if (!L
->getLoopPreheader())
2432 // Disable loop idiom recognition if the function's name is a common idiom.
2433 StringRef Name
= L
->getHeader()->getParent()->getName();
2434 if (Name
== "memset" || Name
== "memcpy" || Name
== "memmove")
2437 DL
= &L
->getHeader()->getModule()->getDataLayout();
2439 HasMemcpy
= TLI
->has(LibFunc_memcpy
);
2440 HasMemmove
= TLI
->has(LibFunc_memmove
);
2442 if (SE
->hasLoopInvariantBackedgeTakenCount(L
))
2443 return runOnCountableLoop(L
);
2447 bool HexagonLoopIdiomRecognizeLegacyPass::runOnLoop(Loop
*L
,
2448 LPPassManager
&LPM
) {
2452 auto *AA
= &getAnalysis
<AAResultsWrapperPass
>().getAAResults();
2453 auto *DT
= &getAnalysis
<DominatorTreeWrapperPass
>().getDomTree();
2454 auto *LF
= &getAnalysis
<LoopInfoWrapperPass
>().getLoopInfo();
2455 auto *TLI
= &getAnalysis
<TargetLibraryInfoWrapperPass
>().getTLI(
2456 *L
->getHeader()->getParent());
2457 auto *SE
= &getAnalysis
<ScalarEvolutionWrapperPass
>().getSE();
2458 return HexagonLoopIdiomRecognize(AA
, DT
, LF
, TLI
, SE
).run(L
);
2461 Pass
*llvm::createHexagonLoopIdiomPass() {
2462 return new HexagonLoopIdiomRecognizeLegacyPass();
2466 HexagonLoopIdiomRecognitionPass::run(Loop
&L
, LoopAnalysisManager
&AM
,
2467 LoopStandardAnalysisResults
&AR
,
2469 return HexagonLoopIdiomRecognize(&AR
.AA
, &AR
.DT
, &AR
.LI
, &AR
.TLI
, &AR
.SE
)
2471 ? getLoopPassPreservedAnalyses()
2472 : PreservedAnalyses::all();