1 //===- MVETailPredication.cpp - MVE Tail Predication ----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
11 /// branches to help accelerate DSP applications. These two extensions can be
12 /// combined to provide implicit vector predication within a low-overhead loop.
13 /// The HardwareLoops pass inserts intrinsics identifying loops that the
14 /// backend will attempt to convert into a low-overhead loop. The vectorizer is
15 /// responsible for generating a vectorized loop in which the lanes are
16 /// predicated upon the iteration counter. This pass looks at these predicated
17 /// vector loops, that are targets for low-overhead loops, and prepares them for
18 /// code generation. Once the vectorizer has produced a masked loop, there's a
19 /// couple of final forms:
20 /// - A tail-predicated loop, with implicit predication.
21 /// - A loop containing multiple VCPT instructions, predicating multiple VPT
22 /// blocks of instructions operating on different vector types.
24 #include "llvm/Analysis/LoopInfo.h"
25 #include "llvm/Analysis/LoopPass.h"
26 #include "llvm/Analysis/ScalarEvolution.h"
27 #include "llvm/Analysis/ScalarEvolutionExpander.h"
28 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
29 #include "llvm/Analysis/TargetTransformInfo.h"
30 #include "llvm/CodeGen/TargetPassConfig.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/IRBuilder.h"
33 #include "llvm/IR/PatternMatch.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
37 #include "ARMSubtarget.h"
41 #define DEBUG_TYPE "mve-tail-predication"
42 #define DESC "Transform predicated vector loops to use MVE tail predication"
// Hidden command-line option "disable-mve-tail-predication". runOnLoop() tests
// this flag together with skipLoop() before doing any work.
// NOTE(review): the option's declared type and default value are not visible
// in this listing.
45 DisableTailPredication("disable-mve-tail-predication", cl::Hidden
,
47 cl::desc("Disable MVE Tail Predication"));
/// Loop pass that inspects vectorized, predicated hardware loops and tries to
/// replace their lane predicates with MVE VCTP intrinsics (see TryConvert).
// NOTE(review): this listing skips some lines of the class (the embedded line
// numbers have gaps), so not every member is visible here.
50 class MVETailPredication
: public LoopPass
{
// Masked load/store intrinsics gathered by IsPredicatedVectorLoop().
51 SmallVector
<IntrinsicInst
*, 4> MaskedInsts
;
// Analysis results; both are assigned in runOnLoop().
53 ScalarEvolution
*SE
= nullptr;
54 TargetTransformInfo
*TTI
= nullptr;
59 MVETailPredication() : LoopPass(ID
) { }
// Declare the analyses this pass requires; LoopInfo is also marked preserved.
61 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
62 AU
.addRequired
<ScalarEvolutionWrapperPass
>();
63 AU
.addRequired
<LoopInfoWrapperPass
>();
64 AU
.addRequired
<TargetPassConfig
>();
65 AU
.addRequired
<TargetTransformInfoWrapperPass
>();
66 AU
.addPreserved
<LoopInfoWrapperPass
>();
// Pass entry point; defined out of line.
70 bool runOnLoop(Loop
*L
, LPPassManager
&) override
;
74 /// Perform the relevant checks on the loop and convert if possible.
75 bool TryConvert(Value
*TripCount
);
77 /// Return whether this is a vectorized loop that contains masked load/store intrinsics.
79 bool IsPredicatedVectorLoop();
81 /// Compute a value for the total number of elements that the predicated
82 /// loop will process.
83 Value
*ComputeElements(Value
*TripCount
, VectorType
*VecTy
);
85 /// Is the icmp that generates an i1 vector, based upon a loop counter
86 /// and a limit that is defined outside the loop.
87 bool isTailPredicate(Instruction
*Predicate
, Value
*NumElements
);
92 static bool IsDecrement(Instruction
&I
) {
93 auto *Call
= dyn_cast
<IntrinsicInst
>(&I
);
97 Intrinsic::ID ID
= Call
->getIntrinsicID();
98 return ID
== Intrinsic::loop_decrement_reg
;
101 static bool IsMasked(Instruction
*I
) {
102 auto *Call
= dyn_cast
<IntrinsicInst
>(I
);
106 Intrinsic::ID ID
= Call
->getIntrinsicID();
107 // TODO: Support gather/scatter expand/compress operations.
108 return ID
== Intrinsic::masked_store
|| ID
== Intrinsic::masked_load
;
/// Pass entry point for loop \p L: fetch the ARM subtarget, TTI and SCEV,
/// locate the hardware-loop setup and decrement intrinsics, and pass the setup
/// intrinsic's first argument to TryConvert().
// NOTE(review): this listing omits several lines of the body (the embedded
// line numbers have gaps), so the statements guarded by the checks below and
// the final return are not visible here.
111 bool MVETailPredication::runOnLoop(Loop
*L
, LPPassManager
&) {
112 if (skipLoop(L
) || DisableTailPredication
)
115 Function
&F
= *L
->getHeader()->getParent();
116 auto &TPC
= getAnalysis
<TargetPassConfig
>();
117 auto &TM
= TPC
.getTM
<TargetMachine
>();
118 auto *ST
= &TM
.getSubtarget
<ARMSubtarget
>(F
);
119 TTI
= &getAnalysis
<TargetTransformInfoWrapperPass
>().getTTI(F
);
120 SE
= &getAnalysis
<ScalarEvolutionWrapperPass
>().getSE();
123 // The MVE and LOB extensions are combined to enable tail-predication, but
124 // there's nothing preventing us from generating VCTP instructions for v8.1m.
125 if (!ST
->hasMVEIntegerOps() || !ST
->hasV8_1MMainlineOps()) {
126 LLVM_DEBUG(dbgs() << "TP: Not a v8.1m.main+mve target.\n");
130 BasicBlock
*Preheader
= L
->getLoopPreheader();
// Helper: scan a block for a call to set_loop_iterations or
// test_set_loop_iterations and return that call.
134 auto FindLoopIterations
= [](BasicBlock
*BB
) -> IntrinsicInst
* {
135 for (auto &I
: *BB
) {
136 auto *Call
= dyn_cast
<IntrinsicInst
>(&I
);
140 Intrinsic::ID ID
= Call
->getIntrinsicID();
141 if (ID
== Intrinsic::set_loop_iterations
||
142 ID
== Intrinsic::test_set_loop_iterations
)
143 return cast
<IntrinsicInst
>(&I
);
148 // Look for the hardware loop intrinsic that sets the iteration count.
149 IntrinsicInst
*Setup
= FindLoopIterations(Preheader
);
151 // The test.set iteration could live in the pre- preheader.
// The second search only looks in the preheader's single predecessor.
153 if (!Preheader
->getSinglePredecessor())
155 Setup
= FindLoopIterations(Preheader
->getSinglePredecessor());
160 // Search for the hardware loop intrinsic that decrements the loop counter.
161 IntrinsicInst
*Decrement
= nullptr;
162 for (auto *BB
: L
->getBlocks()) {
163 for (auto &I
: *BB
) {
164 if (IsDecrement(I
)) {
165 Decrement
= cast
<IntrinsicInst
>(&I
);
174 LLVM_DEBUG(dbgs() << "TP: Running on Loop: " << *L
176 << *Decrement
<< "\n");
// Operand 0 of the setup intrinsic is handed to TryConvert as the trip count.
177 bool Changed
= TryConvert(Setup
->getArgOperand(0));
/// Match \p I against the vectorizer's lane-mask idiom shown below, using
/// PatternMatch, and compare the scalar limit that feeds it with
/// \p NumElements.
// NOTE(review): the statements guarded by the failed-match checks, and the
// tails of some match expressions, are not visible in this listing;
// presumably each failed check rejects the candidate. Confirm against the
// full source.
181 bool MVETailPredication::isTailPredicate(Instruction
*I
, Value
*NumElements
) {
182 // Look for the following:
184 // %trip.count.minus.1 = add i32 %N, -1
185 // %broadcast.splatinsert10 = insertelement <4 x i32> undef,
186 // i32 %trip.count.minus.1, i32 0
187 // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10,
189 // <4 x i32> zeroinitializer
193 // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0
194 // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert,
196 // <4 x i32> zeroinitializer
197 // %induction = add <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3>
198 // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11
200 // And return whether I == %pred.
202 using namespace PatternMatch
;
204 CmpInst::Predicate Pred
;
205 Instruction
*Shuffle
= nullptr;
206 Instruction
*Induction
= nullptr;
// The candidate must be an unsigned <= compare whose right-hand side is
// invariant in the loop.
209 if (!match(I
, m_ICmp(Pred
, m_Instruction(Induction
),
210 m_Instruction(Shuffle
))) ||
211 Pred
!= ICmpInst::ICMP_ULE
|| !L
->isLoopInvariant(Shuffle
))
214 // First find the stuff outside the loop which is setting up the limit
216 // The invariant shuffle that broadcasts the limit into a vector.
217 Instruction
*Insert
= nullptr;
218 if (!match(Shuffle
, m_ShuffleVector(m_Instruction(Insert
), m_Undef(),
222 // Insert the limit into a vector.
223 Instruction
*BECount
= nullptr;
224 if (!match(Insert
, m_InsertElement(m_Undef(), m_Instruction(BECount
),
228 // The limit calculation, backedge count.
// Matched as (TripCount + all-ones), i.e. TripCount - 1.
229 Value
*TripCount
= nullptr;
230 if (!match(BECount
, m_Add(m_Value(TripCount
), m_AllOnes())))
// The recovered trip count has to be the very value the caller computed.
233 if (TripCount
!= NumElements
)
236 // Now back to searching inside the loop body...
237 // Find the add which takes the index iv and adds a constant vector to it.
238 Instruction
*BroadcastSplat
= nullptr;
239 Constant
*Const
= nullptr;
240 if (!match(Induction
, m_Add(m_Instruction(BroadcastSplat
),
244 // Check that we're adding <0, 1, 2, 3...
245 if (auto *CDS
= dyn_cast
<ConstantDataSequential
>(Const
)) {
246 for (unsigned i
= 0; i
< CDS
->getNumElements(); ++i
) {
247 if (CDS
->getElementAsInteger(i
) != i
)
253 // The shuffle which broadcasts the index iv into a vector.
// Insert is reused here for the index-side insertelement.
254 if (!match(BroadcastSplat
, m_ShuffleVector(m_Instruction(Insert
), m_Undef(),
258 // The insert element which initialises a vector with the index iv.
259 Instruction
*IV
= nullptr;
260 if (!match(Insert
, m_InsertElement(m_Undef(), m_Instruction(IV
), m_Zero())))
264 auto *Phi
= dyn_cast
<PHINode
>(IV
);
268 // TODO: Don't think we need to check the entry value.
269 Value
*OnEntry
= Phi
->getIncomingValueForBlock(L
->getLoopPreheader());
270 if (!match(OnEntry
, m_Zero()))
// The value arriving from the latch is matched as (something + lane count).
273 Value
*InLoop
= Phi
->getIncomingValueForBlock(L
->getLoopLatch());
274 unsigned Lanes
= cast
<VectorType
>(Insert
->getType())->getNumElements();
276 Instruction
*LHS
= nullptr;
277 if (!match(InLoop
, m_Add(m_Instruction(LHS
), m_SpecificInt(Lanes
))))
283 static VectorType
* getVectorType(IntrinsicInst
*I
) {
284 unsigned TypeOp
= I
->getIntrinsicID() == Intrinsic::masked_load
? 0 : 1;
285 auto *PtrTy
= cast
<PointerType
>(I
->getOperand(TypeOp
)->getType());
286 return cast
<VectorType
>(PtrTy
->getElementType());
/// Walk every instruction of the loop, collect suitable masked intrinsics in
/// MaskedInsts, and report whether any were collected.
// NOTE(review): the head of the first branch and the statements guarded by
// the width and vector-argument checks are not visible in this listing.
289 bool MVETailPredication::IsPredicatedVectorLoop() {
290 // Check that the loop contains at least one masked load/store intrinsic.
291 // We only support 'normal' vector instructions - other than masked load/stores.
293 for (auto *BB
: L
->getBlocks()) {
294 for (auto &I
: *BB
) {
296 VectorType
*VecTy
= getVectorType(cast
<IntrinsicInst
>(&I
));
297 unsigned Lanes
= VecTy
->getNumElements();
298 unsigned ElementWidth
= VecTy
->getScalarSizeInBits();
299 // MVE vectors are 128-bit, but don't support 128 x i1.
300 // TODO: Can we support vectors larger than 128-bits?
301 unsigned MaxWidth
= TTI
->getRegisterBitWidth(true);
// The access must fill the register width exactly, and must not have as
// many lanes as the register has bits.
302 if (Lanes
* ElementWidth
!= MaxWidth
|| Lanes
== MaxWidth
)
304 MaskedInsts
.push_back(cast
<IntrinsicInst
>(&I
));
// Any other intrinsic call is inspected for vector-typed arguments.
305 } else if (auto *Int
= dyn_cast
<IntrinsicInst
>(&I
)) {
306 for (auto &U
: Int
->args()) {
307 if (isa
<VectorType
>(U
->getType()))
314 return !MaskedInsts
.empty();
/// Recover the element count from the SCEV of \p TripCount by peeling the
/// round-up-and-divide pattern documented below, then expand it as IR at the
/// loop preheader's terminator.
// NOTE(review): the end of the signature, the rejecting branches and the
// assignment of Elems are not visible in this listing.
317 Value
* MVETailPredication::ComputeElements(Value
*TripCount
,
319 const SCEV
*TripCountSE
= SE
->getSCEV(TripCount
);
// VF is the lane count of VecTy as a constant of the trip count's type.
320 ConstantInt
*VF
= ConstantInt::get(cast
<IntegerType
>(TripCount
->getType()),
321 VecTy
->getNumElements());
323 if (VF
->equalsInt(1))
326 // TODO: Support constant trip counts.
// Compares constant operand 0 with -VF and yields operand 1 as a multiply.
327 auto VisitAdd
= [&](const SCEVAddExpr
*S
) -> const SCEVMulExpr
* {
328 if (auto *Const
= dyn_cast
<SCEVConstant
>(S
->getOperand(0))) {
329 if (Const
->getAPInt() != -VF
->getValue())
333 return dyn_cast
<SCEVMulExpr
>(S
->getOperand(1));
// Compares constant operand 0 with VF and yields operand 1 as an unsigned
// divide.
336 auto VisitMul
= [&](const SCEVMulExpr
*S
) -> const SCEVUDivExpr
* {
337 if (auto *Const
= dyn_cast
<SCEVConstant
>(S
->getOperand(0))) {
338 if (Const
->getValue() != VF
)
342 return dyn_cast
<SCEVUDivExpr
>(S
->getOperand(1));
// Compares the divisor with VF, then looks through the ((VF - 1) + x)
// round-up on the dividend and yields x.
345 auto VisitDiv
= [&](const SCEVUDivExpr
*S
) -> const SCEV
* {
346 if (auto *Const
= dyn_cast
<SCEVConstant
>(S
->getRHS())) {
347 if (Const
->getValue() != VF
)
352 if (auto *RoundUp
= dyn_cast
<SCEVAddExpr
>(S
->getLHS())) {
353 if (auto *Const
= dyn_cast
<SCEVConstant
>(RoundUp
->getOperand(0))) {
354 if (Const
->getAPInt() != (VF
->getValue() - 1))
359 return RoundUp
->getOperand(1);
364 // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to
365 // determine the numbers of elements instead? Looks like this is what is used
366 // for delinearization, but I'm not sure if it can be applied to the
367 // vectorized form - at least not without a bit more work than I feel
370 // Search for Elems in the following SCEV:
371 // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw>
372 const SCEV
*Elems
= nullptr;
// Peel the expression from the outside in; each level must have the
// expected node kind.
373 if (auto *TC
= dyn_cast
<SCEVAddExpr
>(TripCountSE
))
374 if (auto *Div
= dyn_cast
<SCEVUDivExpr
>(TC
->getOperand(1)))
375 if (auto *Add
= dyn_cast
<SCEVAddExpr
>(Div
->getLHS()))
376 if (auto *Mul
= VisitAdd(Add
))
377 if (auto *Div
= VisitMul(Mul
))
378 if (auto *Res
= VisitDiv(Div
))
// The expansion point is the preheader's terminator; the expression is
// checked with isSafeToExpandAt first.
384 Instruction
*InsertPt
= L
->getLoopPreheader()->getTerminator();
385 if (!isSafeToExpandAt(Elems
, InsertPt
, *SE
))
388 auto DL
= L
->getHeader()->getModule()->getDataLayout();
389 SCEVExpander
Expander(*SE
, DL
, "elements");
390 return Expander
.expandCodeFor(Elems
, Elems
->getType(), InsertPt
);
393 // Look through the exit block to see whether there's a duplicate predicate
394 // instruction. This can happen when we need to perform a select on values
395 // from the last and previous iteration. Instead of doing a straight
396 // replacement of that predicate with the vctp, clone the vctp and place it
397 // in the block. This means that the VPR doesn't have to be live into the
398 // exit block which should make it easier to convert this loop into a proper
399 // tail predicated loop.
//
// NewPredicates maps each replaced predicate to its replacement. MaybeDead is
// then used as a worklist of instructions that may have become unused.
// NOTE(review): the worklist pop, the bookkeeping of the Dead set and the
// body of the final loop over the loop's blocks are not visible in this
// listing.
400 static void Cleanup(DenseMap
<Instruction
*, Instruction
*> &NewPredicates
,
401 SetVector
<Instruction
*> &MaybeDead
, Loop
*L
) {
// Only loops with a unique exit block are scanned for duplicates.
402 if (BasicBlock
*Exit
= L
->getUniqueExitBlock()) {
403 for (auto &Pair
: NewPredicates
) {
404 Instruction
*OldPred
= Pair
.first
;
405 Instruction
*NewPred
= Pair
.second
;
// An exit-block instruction that is the same operation as the old predicate
// is replaced by a clone of the new predicate inserted just before it.
407 for (auto &I
: *Exit
) {
408 if (I
.isSameOperationAs(OldPred
)) {
409 Instruction
*PredClone
= NewPred
->clone();
410 PredClone
->insertBefore(&I
);
411 I
.replaceAllUsesWith(PredClone
);
412 MaybeDead
.insert(&I
);
419 // Drop references and add operands to check for dead.
420 SmallPtrSet
<Instruction
*, 4> Dead
;
421 while (!MaybeDead
.empty()) {
422 auto *I
= MaybeDead
.front();
// Instructions that still have at least one use are tested for here.
424 if (I
->hasNUsesOrMore(1))
// Operands that are instructions are queued for the same check.
427 for (auto &U
: I
->operands()) {
428 if (auto *OpI
= dyn_cast
<Instruction
>(U
))
429 MaybeDead
.insert(OpI
);
431 I
->dropAllReferences();
436 I
->eraseFromParent();
438 for (auto I
: L
->blocks())
/// For each collected masked intrinsic whose predicate matches the tail
/// predicate pattern, insert an element-count phi in the loop header and a
/// VCTP intrinsic call, replace the predicate's uses with that call, and
/// finally run Cleanup() over the replaced predicates.
// NOTE(review): the early exits, the loop-control statements after the guards
// and the return value are not visible in this listing.
442 bool MVETailPredication::TryConvert(Value
*TripCount
) {
443 if (!IsPredicatedVectorLoop())
446 LLVM_DEBUG(dbgs() << "TP: Found predicated vector loop.\n");
448 // Walk through the masked intrinsics and try to find whether the predicate
449 // operand is generated from an induction variable.
450 Module
*M
= L
->getHeader()->getModule();
// The element-count phi is created as a 32-bit integer.
451 Type
*Ty
= IntegerType::get(M
->getContext(), 32);
452 SetVector
<Instruction
*> Predicates
;
453 DenseMap
<Instruction
*, Instruction
*> NewPredicates
;
455 for (auto *I
: MaskedInsts
) {
// The mask is argument 2 of masked_load and argument 3 otherwise.
456 Intrinsic::ID ID
= I
->getIntrinsicID();
457 unsigned PredOp
= ID
== Intrinsic::masked_load
? 2 : 3;
458 auto *Predicate
= dyn_cast
<Instruction
>(I
->getArgOperand(PredOp
));
// Non-instruction masks and predicates already recorded are tested for here.
459 if (!Predicate
|| Predicates
.count(Predicate
))
462 VectorType
*VecTy
= getVectorType(I
);
463 Value
*NumElements
= ComputeElements(TripCount
, VecTy
);
467 if (!isTailPredicate(Predicate
, NumElements
)) {
468 LLVM_DEBUG(dbgs() << "TP: Not tail predicate: " << *Predicate
<< "\n");
472 LLVM_DEBUG(dbgs() << "TP: Found tail predicate: " << *Predicate
<< "\n");
473 Predicates
.insert(Predicate
);
475 // Insert a phi to count the number of elements processed by the loop.
// Despite its name, Processed counts down: it enters the loop as NumElements
// and the latch value is Processed minus the lane count.
476 IRBuilder
<> Builder(L
->getHeader()->getFirstNonPHI());
477 PHINode
*Processed
= Builder
.CreatePHI(Ty
, 2);
478 Processed
->addIncoming(NumElements
, L
->getLoopPreheader());
480 // Insert the intrinsic to represent the effect of tail predication.
481 Builder
.SetInsertPoint(cast
<Instruction
>(Predicate
));
482 ConstantInt
*Factor
=
483 ConstantInt::get(cast
<IntegerType
>(Ty
), VecTy
->getNumElements());
// Choose the VCTP intrinsic from the lane count: 2, 4, 8 and 16 lanes map to
// vctp64, vctp32, vctp16 and vctp8.
484 Intrinsic::ID VCTPID
;
485 switch (VecTy
->getNumElements()) {
487 llvm_unreachable("unexpected number of lanes");
488 case 2: VCTPID
= Intrinsic::arm_vctp64
; break;
489 case 4: VCTPID
= Intrinsic::arm_vctp32
; break;
490 case 8: VCTPID
= Intrinsic::arm_vctp16
; break;
491 case 16: VCTPID
= Intrinsic::arm_vctp8
; break;
493 Function
*VCTP
= Intrinsic::getDeclaration(M
, VCTPID
);
494 Value
*TailPredicate
= Builder
.CreateCall(VCTP
, Processed
);
495 Predicate
->replaceAllUsesWith(TailPredicate
);
// Remember the old-to-new mapping for Cleanup().
496 NewPredicates
[Predicate
] = cast
<Instruction
>(TailPredicate
);
498 // Add the incoming value to the new phi.
499 // TODO: This add likely already exists in the loop.
500 Value
*Remaining
= Builder
.CreateSub(Processed
, Factor
);
501 Processed
->addIncoming(Remaining
, L
->getLoopLatch());
502 LLVM_DEBUG(dbgs() << "TP: Insert processed elements phi: "
503 << *Processed
<< "\n"
504 << "TP: Inserted VCTP: " << *TailPredicate
<< "\n");
// The replaced predicates are passed to Cleanup as its MaybeDead worklist.
508 Cleanup(NewPredicates
, Predicates
, L
);
512 Pass
*llvm::createMVETailPredicationPass() {
513 return new MVETailPredication();
// Definition of the pass ID member; the constructor passes ID to LoopPass.
516 char MVETailPredication::ID
= 0;
// Register the pass under DEBUG_TYPE ("mve-tail-predication") with the
// description DESC; no dependencies are listed between BEGIN and END in this
// listing.
518 INITIALIZE_PASS_BEGIN(MVETailPredication
, DEBUG_TYPE
, DESC
, false, false)
519 INITIALIZE_PASS_END(MVETailPredication
, DEBUG_TYPE
, DESC
, false, false)