//===- AArch64LoopIdiomTransform.cpp - Loop idiom recognition -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass recognizes certain loop idioms and transforms them into more
// optimized versions of the same loop. In cases where this happens, it can be
// a significant performance win.
//
// We currently only recognize one loop that finds the first mismatched byte
// in an array and returns the index, i.e. something like:
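//
//  while (++i != n) {
//    if (a[i] != b[i])
//      break;
//  }
//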
// In this example we can actually vectorize the loop despite the early exit,
// although the loop vectorizer does not support it. It requires some extra
// checks to deal with the possibility of faulting loads when crossing page
// boundaries. However, even with these checks it is still profitable to do
// the transformation.
//
//===----------------------------------------------------------------------===//
//
// TODO List:
//
// * Add support for the inverse case where we scan for a matching element.
// * Permit 64-bit induction variable types.
// * Recognize loops that increment the IV *after* comparing bytes.
// * Allow 32-bit sign-extends of the IV used by the GEP.
//
//===----------------------------------------------------------------------===//
38 #include "AArch64LoopIdiomTransform.h"
39 #include "llvm/Analysis/DomTreeUpdater.h"
40 #include "llvm/Analysis/LoopPass.h"
41 #include "llvm/Analysis/TargetTransformInfo.h"
42 #include "llvm/IR/Dominators.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/Intrinsics.h"
45 #include "llvm/IR/MDBuilder.h"
46 #include "llvm/IR/PatternMatch.h"
47 #include "llvm/InitializePasses.h"
48 #include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "aarch64-loop-idiom-transform"

static cl::opt<bool>
    DisableAll("disable-aarch64-lit-all", cl::Hidden, cl::init(false),
               cl::desc("Disable AArch64 Loop Idiom Transform Pass."));

static cl::opt<bool> DisableByteCmp(
    "disable-aarch64-lit-bytecmp", cl::Hidden, cl::init(false),
    cl::desc("Proceed with AArch64 Loop Idiom Transform Pass, but do "
             "not convert byte-compare loop(s)."));

static cl::opt<bool> VerifyLoops(
    "aarch64-lit-verify", cl::Hidden, cl::init(false),
    cl::desc("Verify loops generated by the AArch64 Loop Idiom Transform "
             "Pass."));

namespace llvm {
void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &);
Pass *createAArch64LoopIdiomTransformPass();
} // end namespace llvm

namespace {

class AArch64LoopIdiomTransform {
  Loop *CurLoop = nullptr;
  DominatorTree *DT;
  LoopInfo *LI;
  const TargetTransformInfo *TTI;
  const DataLayout *DL;

public:
  explicit AArch64LoopIdiomTransform(DominatorTree *DT, LoopInfo *LI,
                                     const TargetTransformInfo *TTI,
                                     const DataLayout *DL)
      : DT(DT), LI(LI), TTI(TTI), DL(DL) {}

  bool run(Loop *L);

private:
  /// \name Countable Loop Idiom Handling
  /// @{

  bool runOnCountableLoop();
  bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
                      SmallVectorImpl<BasicBlock *> &ExitBlocks);

  bool recognizeByteCompare();
  Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                            GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            Instruction *Index, Value *Start, Value *MaxLen);
  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                            PHINode *IndPhi, Value *MaxLen, Instruction *Index,
                            Value *Start, bool IncIdx, BasicBlock *FoundBB,
                            BasicBlock *EndBB);
  /// @}
};

class AArch64LoopIdiomTransformLegacyPass : public LoopPass {
public:
  static char ID;

  explicit AArch64LoopIdiomTransformLegacyPass() : LoopPass(ID) {
    initializeAArch64LoopIdiomTransformLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  StringRef getPassName() const override {
    return "Transform AArch64-specific loop idioms";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
};

bool AArch64LoopIdiomTransformLegacyPass::runOnLoop(Loop *L,
                                                    LPPassManager &LPM) {
  if (skipLoop(L))
    return false;

  auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
      *L->getHeader()->getParent());
  return AArch64LoopIdiomTransform(
             DT, LI, &TTI, &L->getHeader()->getModule()->getDataLayout())
      .run(L);
}

} // end anonymous namespace

char AArch64LoopIdiomTransformLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(
    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
    "Transform specific loop idioms into optimized vector forms", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(
    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
    "Transform specific loop idioms into optimized vector forms", false, false)

Pass *llvm::createAArch64LoopIdiomTransformPass() {
  return new AArch64LoopIdiomTransformLegacyPass();
}

PreservedAnalyses
AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
                                   LoopStandardAnalysisResults &AR,
                                   LPMUpdater &U) {
  if (DisableAll)
    return PreservedAnalyses::all();

  const auto *DL = &L.getHeader()->getModule()->getDataLayout();

  AArch64LoopIdiomTransform LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
  if (!LIT.run(&L))
    return PreservedAnalyses::all();

  return PreservedAnalyses::none();
}

//===----------------------------------------------------------------------===//
//
// Implementation of AArch64LoopIdiomTransform
//
//===----------------------------------------------------------------------===//

bool AArch64LoopIdiomTransform::run(Loop *L) {
  CurLoop = L;

  Function &F = *L->getHeader()->getParent();
  if (DisableAll || F.hasOptSize())
    return false;

  if (F.hasFnAttribute(Attribute::NoImplicitFloat)) {
    LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName()
                      << " due to its NoImplicitFloat attribute");
    return false;
  }

  // If the loop could not be converted to canonical form, it must have an
  // indirectbr in it, just give up.
  if (!L->getLoopPreheader())
    return false;

  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
                    << CurLoop->getHeader()->getName() << "\n");

  return recognizeByteCompare();
}

bool AArch64LoopIdiomTransform::recognizeByteCompare() {
  // Currently the transformation only works on scalable vector types, although
  // there is no fundamental reason why it cannot be made to work for fixed
  // width too.

  // We also need to know the minimum page size for the target in order to
  // generate runtime memory checks to ensure the vector version won't fault.
  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
      DisableByteCmp)
    return false;

  BasicBlock *Header = CurLoop->getHeader();

  // In AArch64LoopIdiomTransform::run we have already checked that the loop
  // has a preheader so we can assume it's in a canonical form.
  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
    return false;

  PHINode *PN = dyn_cast<PHINode>(&Header->front());
  if (!PN || PN->getNumIncomingValues() != 2)
    return false;

  auto LoopBlocks = CurLoop->getBlocks();
  // The first block in the loop should contain only 4 instructions, e.g.
  //
  //  while.cond:
  //   %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
  //   %inc = add i32 %res.phi, 1
  //   %cmp.not = icmp eq i32 %inc, %n
  //   br i1 %cmp.not, label %while.end, label %while.body
  //
  auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
  if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4)
    return false;

  // The second block should contain 7 instructions, e.g.
  //
  // while.body:
  //   %idx = zext i32 %inc to i64
  //   %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
  //   %load.a = load i8, ptr %idx.a
  //   %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
  //   %load.b = load i8, ptr %idx.b
  //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
  //   br i1 %cmp.not.ld, label %while.cond, label %while.end
  //
  auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
  if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
    return false;

  // The incoming value to the PHI node from the loop should be an add of 1.
  Value *StartIdx = nullptr;
  Instruction *Index = nullptr;
  if (!CurLoop->contains(PN->getIncomingBlock(0))) {
    StartIdx = PN->getIncomingValue(0);
    Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
  } else {
    StartIdx = PN->getIncomingValue(1);
    Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
  }

  // Limit to 32-bit types for now.
  if (!Index || !Index->getType()->isIntegerTy(32) ||
      !match(Index, m_c_Add(m_Specific(PN), m_One())))
    return false;

  // If we match the pattern, PN and Index will be replaced with the result of
  // the cttz.elts intrinsic. If any other instructions are used outside of
  // the loop, we cannot replace it.
  for (BasicBlock *BB : LoopBlocks)
    for (Instruction &I : *BB)
      if (&I != PN && &I != Index)
        for (User *U : I.users())
          if (!CurLoop->contains(cast<Instruction>(U)))
            return false;

  // Match the branch instruction for the header.
  ICmpInst::Predicate Pred;
  Value *MaxLen;
  BasicBlock *EndBB, *WhileBB;
  if (!match(Header->getTerminator(),
             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
    return false;

  // WhileBB should contain the pattern of load & compare instructions. Match
  // the pattern and find the GEP instructions used by the loads.
  ICmpInst::Predicate WhilePred;
  BasicBlock *FoundBB;
  BasicBlock *TrueBB;
  Value *LoadA, *LoadB;
  if (!match(WhileBB->getTerminator(),
             m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
      WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
    return false;

  Value *A, *B;
  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
    return false;

  LoadInst *LoadAI = cast<LoadInst>(LoadA);
  LoadInst *LoadBI = cast<LoadInst>(LoadB);
  if (!LoadAI->isSimple() || !LoadBI->isSimple())
    return false;

  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
  if (!GEPA || !GEPB)
    return false;

  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // Check we are loading i8 values from two loop-invariant pointers.
  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
      !GEPA->getResultElementType()->isIntegerTy(8) ||
      !GEPB->getResultElementType()->isIntegerTy(8) ||
      !LoadAI->getType()->isIntegerTy(8) ||
      !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
    return false;

  // Check that the index to the GEPs is the index we found earlier.
  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
    return false;

  Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
  Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
    return false;

  // We only ever expect the pre-incremented index value to be used inside the
  // loop.
  if (!PN->hasOneUse())
    return false;

  // Ensure that when the Found and End blocks are identical the PHIs have the
  // supported format. We don't currently allow cases like this:
  //
  // while.cond:
  //   ...
  //   br i1 %cmp.not, label %while.end, label %while.body
  //
  // while.body:
  //   ...
  //   br i1 %cmp.not2, label %while.cond, label %while.end
  //
  // while.end:
  //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
  //
  // Where the incoming values for %final_ptr are unique and from each of the
  // loop blocks, but not actually defined in the loop. This requires extra
  // work setting up the byte.compare block, i.e. by introducing a select to
  // choose the correct value.
  // TODO: We could add support for this in future.
  if (FoundBB == EndBB) {
    for (PHINode &EndPN : EndBB->phis()) {
      Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
      Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);

      // The value of the index when leaving the while.cond block is always the
      // same as the end value (MaxLen) so we permit either. The value when
      // leaving the while.body block should only be the index. Otherwise for
      // any other values we only allow ones that are the same for both blocks.
      if (WhileCondVal != WhileBodyVal &&
          ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
           (WhileBodyVal != Index)))
        return false;
    }
  }

  LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
                    << *(EndBB->getParent()) << "\n\n");

  // The index is incremented before the GEP/Load pair so we need to
  // add 1 to the start value.
  transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true,
                       FoundBB, EndBB);
  return true;
}

Value *AArch64LoopIdiomTransform::expandFindMismatch(
    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
    GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
  Value *PtrA = GEPA->getPointerOperand();
  Value *PtrB = GEPB->getPointerOperand();

  // Get the arguments and types for the intrinsic.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  LLVMContext &Ctx = PHBranch->getContext();
  Type *LoadType = Type::getInt8Ty(Ctx);
  Type *ResType = Builder.getInt32Ty();

  // Split block in the original loop preheader.
  BasicBlock *EndBlock =
      SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");

  // Create the blocks that we're going to need (a CFG sketch follows below):
  //  1. A block for checking the zero-extended length exceeds 0.
  //  2. A block to check that the start and end addresses of a given array
  //     lie on the same page.
  //  3. The SVE loop preheader.
  //  4. The first SVE loop block.
  //  5. The SVE loop increment block.
  //  6. A block we can jump to from the SVE loop when a mismatch is found.
  //  7. The first block of the scalar loop itself, containing PHIs, loads
  //     and compares.
  //  8. A scalar loop increment block to increment the PHIs and go back
  //     around the scalar loop.
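  //
  // A rough sketch of the control flow this creates (illustrative only;
  // block names match the ones used below):
  //
  //   mismatch_min_it_check --(Start > MaxLen)---> mismatch_loop_pre
  //           |
  //   mismatch_mem_check --(page crossed)--------> mismatch_loop_pre
  //           |
  //   mismatch_sve_loop_preheader
  //           |
  //   mismatch_sve_loop <--> mismatch_sve_loop_inc --(no lanes left)--> mismatch_end
  //           |
  //   mismatch_sve_loop_found ----------------------------------------> mismatch_end
  //
  //   mismatch_loop_pre --> mismatch_loop <--> mismatch_loop_inc -----> mismatch_end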

  BasicBlock *MinItCheckBlock = BasicBlock::Create(
      Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);

  // Update the terminator added by SplitBlock to branch to the first block.
  Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);

  BasicBlock *MemCheckBlock = BasicBlock::Create(
      Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);

  BasicBlock *SVELoopPreheaderBlock = BasicBlock::Create(
      Ctx, "mismatch_sve_loop_preheader", EndBlock->getParent(), EndBlock);

  BasicBlock *SVELoopStartBlock = BasicBlock::Create(
      Ctx, "mismatch_sve_loop", EndBlock->getParent(), EndBlock);

  BasicBlock *SVELoopIncBlock = BasicBlock::Create(
      Ctx, "mismatch_sve_loop_inc", EndBlock->getParent(), EndBlock);

  BasicBlock *SVELoopMismatchBlock = BasicBlock::Create(
      Ctx, "mismatch_sve_loop_found", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
      Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopStartBlock =
      BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);

  BasicBlock *LoopIncBlock = BasicBlock::Create(
      Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);

  DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
                    {DominatorTree::Delete, Preheader, EndBlock}});

  // Update LoopInfo with the new SVE & scalar loops.
  auto SVELoop = LI->AllocateLoop();
  auto ScalarLoop = LI->AllocateLoop();

  if (CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(SVELoopPreheaderBlock, *LI);
    CurLoop->getParentLoop()->addChildLoop(SVELoop);
    CurLoop->getParentLoop()->addBasicBlockToLoop(SVELoopMismatchBlock, *LI);
    CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
    CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
  } else {
    LI->addTopLevelLoop(SVELoop);
    LI->addTopLevelLoop(ScalarLoop);
  }

  // Add the new basic blocks to their associated loops.
  SVELoop->addBasicBlockToLoop(SVELoopStartBlock, *LI);
  SVELoop->addBasicBlockToLoop(SVELoopIncBlock, *LI);

  ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
  ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);

  // Set up some types and constants that we intend to reuse.
  Type *I64Type = Builder.getInt64Ty();

  // Check the zero-extended iteration count > 0.
  Builder.SetInsertPoint(MinItCheckBlock);
  Value *ExtStart = Builder.CreateZExt(Start, I64Type);
  Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
  // This check doesn't really cost us very much.
  Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
  BranchInst *MinItCheckBr =
      BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
  MinItCheckBr->setMetadata(
      LLVMContext::MD_prof,
      MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
  Builder.Insert(MinItCheckBr);

  DTU.applyUpdates(
      {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
       {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});

  // For each of the arrays, check the start/end addresses are on the same
  // page.
  Builder.SetInsertPoint(MemCheckBlock);

  // The early exit in the original loop means that when performing vector
  // loads we are potentially reading ahead of the early exit. So we could
  // fault if crossing a page boundary. Therefore, we create runtime memory
  // checks based on the minimum page size as follows:
  //  1. Calculate the addresses of the first memory accesses in the loop,
  //     i.e. LhsStart and RhsStart.
  //  2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
  //  3. Determine which pages correspond to all the memory accesses, i.e.
  //     LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
  //  4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
  //     we know we won't cross any page boundaries in the loop so we can
  //     enter the vector loop! Otherwise we fall back on the scalar loop.
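  //
  // For example, with a minimum page size of 4096 bytes the checks below
  // reduce to comparing bits [63:12] of the first and last byte addresses
  // accessed in each array, i.e. (LhsStart >> 12) == (LhsEnd >> 12).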
  Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart);
  Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart);
  Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type);
  Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type);
  Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd);
  Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd);
  Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type);
  Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type);

  const uint64_t MinPageSize = TTI->getMinPageSize().value();
  const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
  Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt);
  Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt);
  Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt);
  Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt);
  Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage);
  Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage);

  Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp);
  BranchInst *CombinedPageCmpCmpBr = BranchInst::Create(
      LoopPreHeaderBlock, SVELoopPreheaderBlock, CombinedPageCmp);
  CombinedPageCmpCmpBr->setMetadata(
      LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext())
                                .createBranchWeights(10, 90));
  Builder.Insert(CombinedPageCmpCmpBr);

  DTU.applyUpdates(
      {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
       {DominatorTree::Insert, MemCheckBlock, SVELoopPreheaderBlock}});

  // Set up the SVE loop preheader, i.e. calculate initial loop predicate,
  // zero-extend MaxLen to 64-bits, determine the number of vector elements
  // processed in each iteration, etc.
  Builder.SetInsertPoint(SVELoopPreheaderBlock);

  // At this point we know two things must be true:
  //  1. Start <= MaxLen due to the MinItCheck we added before entry.
  //  2. ExtMaxLen <= MinPageSize due to the page checks.
  // Therefore, we know that we can use a 64-bit induction variable that
  // starts from 0 -> ExtMaxLen and it will not overflow.
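  //
  // Note: get_active_lane_mask(Base, N) sets lane i of the returned predicate
  // to (Base + i < N), so the initial predicate below enables one lane for
  // each byte still left to compare, up to the vector length.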
  ScalableVectorType *PredVTy =
      ScalableVectorType::get(Builder.getInt1Ty(), 16);

  Value *InitialPred = Builder.CreateIntrinsic(
      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});

  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
  VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
                             /*HasNUW=*/true, /*HasNSW=*/true);

  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
                                            Builder.getInt1(false));

  BranchInst *JumpToSVELoop = BranchInst::Create(SVELoopStartBlock);
  Builder.Insert(JumpToSVELoop);

  DTU.applyUpdates(
      {{DominatorTree::Insert, SVELoopPreheaderBlock, SVELoopStartBlock}});

  // Set up the first SVE loop block by creating the PHIs, doing the vector
  // loads and comparing the vectors.
  Builder.SetInsertPoint(SVELoopStartBlock);
  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_sve_loop_pred");
  LoopPred->addIncoming(InitialPred, SVELoopPreheaderBlock);
  PHINode *SVEIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_sve_index");
  SVEIndexPhi->addIncoming(ExtStart, SVELoopPreheaderBlock);
  Type *SVELoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
  Value *Passthru = ConstantInt::getNullValue(SVELoadType);

  Value *SVELhsGep = Builder.CreateGEP(LoadType, PtrA, SVEIndexPhi);
  if (GEPA->isInBounds())
    cast<GetElementPtrInst>(SVELhsGep)->setIsInBounds(true);
  Value *SVELhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVELhsGep, Align(1),
                                               LoopPred, Passthru);

  Value *SVERhsGep = Builder.CreateGEP(LoadType, PtrB, SVEIndexPhi);
  if (GEPB->isInBounds())
    cast<GetElementPtrInst>(SVERhsGep)->setIsInBounds(true);
  Value *SVERhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVERhsGep, Align(1),
                                               LoopPred, Passthru);

  Value *SVEMatchCmp = Builder.CreateICmpNE(SVELhsLoad, SVERhsLoad);
  SVEMatchCmp = Builder.CreateSelect(LoopPred, SVEMatchCmp, PFalse);
  Value *SVEMatchHasActiveLanes = Builder.CreateOrReduce(SVEMatchCmp);
  BranchInst *SVEEarlyExit = BranchInst::Create(
      SVELoopMismatchBlock, SVELoopIncBlock, SVEMatchHasActiveLanes);
  Builder.Insert(SVEEarlyExit);

  DTU.applyUpdates(
      {{DominatorTree::Insert, SVELoopStartBlock, SVELoopMismatchBlock},
       {DominatorTree::Insert, SVELoopStartBlock, SVELoopIncBlock}});

  // Increment the index counter and calculate the predicate for the next
  // iteration of the loop. We branch back to the start of the loop if there
  // is at least one active lane.
  Builder.SetInsertPoint(SVELoopIncBlock);
  Value *NewSVEIndexPhi = Builder.CreateAdd(SVEIndexPhi, VecLen, "",
                                            /*HasNUW=*/true, /*HasNSW=*/true);
  SVEIndexPhi->addIncoming(NewSVEIndexPhi, SVELoopIncBlock);
  Value *NewPred =
      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
                              {PredVTy, I64Type}, {NewSVEIndexPhi, ExtEnd});
  LoopPred->addIncoming(NewPred, SVELoopIncBlock);

  Value *PredHasActiveLanes =
      Builder.CreateExtractElement(NewPred, uint64_t(0));
  BranchInst *SVELoopBranchBack =
      BranchInst::Create(SVELoopStartBlock, EndBlock, PredHasActiveLanes);
  Builder.Insert(SVELoopBranchBack);

  DTU.applyUpdates({{DominatorTree::Insert, SVELoopIncBlock, SVELoopStartBlock},
                    {DominatorTree::Insert, SVELoopIncBlock, EndBlock}});

  // If we found a mismatch then we need to calculate which lane in the vector
  // had a mismatch and add that on to the current loop index.
  Builder.SetInsertPoint(SVELoopMismatchBlock);
  PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_sve_found_pred");
  FoundPred->addIncoming(SVEMatchCmp, SVELoopStartBlock);
  PHINode *LastLoopPred =
      Builder.CreatePHI(PredVTy, 1, "mismatch_sve_last_loop_pred");
  LastLoopPred->addIncoming(LoopPred, SVELoopStartBlock);
  PHINode *SVEFoundIndex =
      Builder.CreatePHI(I64Type, 1, "mismatch_sve_found_index");
  SVEFoundIndex->addIncoming(SVEIndexPhi, SVELoopStartBlock);

  Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
  Value *Ctz = Builder.CreateIntrinsic(
      Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
      {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
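  // cttz.elts counts the trailing inactive lanes, i.e. it yields the lane
  // index of the first active element in PredMatchCmp, which is the offset
  // of the first mismatching byte within this vector iteration.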
  Ctz = Builder.CreateZExt(Ctz, I64Type);
  Value *SVELoopRes64 = Builder.CreateAdd(SVEFoundIndex, Ctz, "",
                                          /*HasNUW=*/true, /*HasNSW=*/true);
  Value *SVELoopRes = Builder.CreateTrunc(SVELoopRes64, ResType);

  Builder.Insert(BranchInst::Create(EndBlock));

  DTU.applyUpdates({{DominatorTree::Insert, SVELoopMismatchBlock, EndBlock}});

  // Generate code for scalar loop.
  Builder.SetInsertPoint(LoopPreHeaderBlock);
  Builder.Insert(BranchInst::Create(LoopStartBlock));

  DTU.applyUpdates(
      {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});

  Builder.SetInsertPoint(LoopStartBlock);
  PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
  IndexPhi->addIncoming(Start, LoopPreHeaderBlock);

  // Load a byte from each array and compare them.
  Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type);

  Value *LhsGep = Builder.CreateGEP(LoadType, PtrA, GepOffset);
  if (GEPA->isInBounds())
    cast<GetElementPtrInst>(LhsGep)->setIsInBounds(true);
  Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep);

  Value *RhsGep = Builder.CreateGEP(LoadType, PtrB, GepOffset);
  if (GEPB->isInBounds())
    cast<GetElementPtrInst>(RhsGep)->setIsInBounds(true);
  Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);

  Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
  // If we have a mismatch then exit the loop ...
  BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
  Builder.Insert(MatchCmpBr);

  DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
                    {DominatorTree::Insert, LoopStartBlock, EndBlock}});

  // Have we reached the maximum permitted length for the loop?
  Builder.SetInsertPoint(LoopIncBlock);
  Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "",
                                    /*HasNUW=*/Index->hasNoUnsignedWrap(),
                                    /*HasNSW=*/Index->hasNoSignedWrap());
  IndexPhi->addIncoming(PhiInc, LoopIncBlock);
  Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen);
  BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp);
  Builder.Insert(IVCmpBr);

  DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock},
                    {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});

  // In the end block we need to insert a PHI node to deal with four cases:
  //  1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
  //  2. We exited the scalar loop early due to a mismatch and need to return
  //     the index that we found.
  //  3. We didn't find a mismatch in the SVE loop, so we return MaxLen.
  //  4. We exited the SVE loop early due to a mismatch and need to return
  //     the index that we found.
  Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt());
  PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result");
  ResPhi->addIncoming(MaxLen, LoopIncBlock);
  ResPhi->addIncoming(IndexPhi, LoopStartBlock);
  ResPhi->addIncoming(MaxLen, SVELoopIncBlock);
  ResPhi->addIncoming(SVELoopRes, SVELoopMismatchBlock);

  Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType);

  if (VerifyLoops) {
    ScalarLoop->verifyLoop();
    SVELoop->verifyLoop();
    if (!SVELoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
    if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }

  return FinalRes;
}

void AArch64LoopIdiomTransform::transformByteCompare(
    GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, PHINode *IndPhi,
    Value *MaxLen, Instruction *Index, Value *Start, bool IncIdx,
    BasicBlock *FoundBB, BasicBlock *EndBB) {

  // Insert the byte compare code at the end of the preheader block.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BasicBlock *Header = CurLoop->getHeader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  IRBuilder<> Builder(PHBranch);
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());

  // Increment the start index if the original loop incremented it before the
  // loads.
  if (IncIdx)
    Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));

  Value *ByteCmpRes =
      expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);

  // Replace uses of the index & induction Phi with the intrinsic result (we
  // already checked that the first instruction of Header is the Phi above).
  assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
  Index->replaceAllUsesWith(ByteCmpRes);

  assert(PHBranch->isUnconditional() &&
         "Expected preheader to terminate with an unconditional branch.");

  // If no mismatch was found, we can jump to the end block. Create a
  // new basic block for the compare instruction.
  auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare",
                                   Preheader->getParent());
  CmpBB->moveBefore(EndBB);

  // Replace the branch in the preheader with an always-true conditional
  // branch. This ensures there is still a reference to the original loop.
  Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header);
  PHBranch->eraseFromParent();

  BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent();
  DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}});

  // Create the branch to either the end or found block depending on the value
  // returned by the intrinsic.
  Builder.SetInsertPoint(CmpBB);
  if (FoundBB != EndBB) {
    Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen);
    Builder.CreateCondBr(FoundCmp, EndBB, FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB},
                      {DominatorTree::Insert, CmpBB, EndBB}});
  } else {
    Builder.CreateBr(FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}});
  }

  auto fixSuccessorPhis = [&](BasicBlock *SuccBB) {
    for (PHINode &PN : SuccBB->phis()) {
      // At this point we've already replaced all uses of the result from the
      // loop with ByteCmpRes. Look through the incoming values to find
      // ByteCmpRes, meaning this is a Phi collecting the results of the byte
      // compare.
      bool ResPhi = false;
      for (Value *Op : PN.incoming_values())
        if (Op == ByteCmpRes) {
          ResPhi = true;
          break;
        }

      // Any PHI that depended upon the result of the byte compare needs a new
      // incoming value from CmpBB. This is because the original loop will get
      // deleted.
      if (ResPhi)
        PN.addIncoming(ByteCmpRes, CmpBB);
      else {
        // There should be no other outside uses of other values in the
        // original loop. Any incoming values should either:
        //  1. Be for blocks outside the loop, which aren't interesting. Or ...
        //  2. Be from blocks in the loop with values defined outside the
        //     loop. We should add a similar incoming value from CmpBB.
        for (BasicBlock *BB : PN.blocks())
          if (CurLoop->contains(BB)) {
            PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB);
            break;
          }
      }
    }
  };

  // Ensure all Phis in the successors of CmpBB have an incoming value from it.
  fixSuccessorPhis(EndBB);
  if (EndBB != FoundBB)
    fixSuccessorPhis(FoundBB);

  // The new CmpBB block isn't part of the loop, but will need to be added to
  // the outer loop if there is one.
  if (!CurLoop->isOutermost())
    CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI);

  if (VerifyLoops && CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->verifyLoop();
    if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }
}