1 //===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 /// Armv6 introduced instructions to perform 32-bit SIMD operations. The
11 /// purpose of this pass is to do some IR pattern matching to create ACLE
12 /// DSP intrinsics, which map onto these 32-bit SIMD operations.
13 /// This pass runs only when unaligned accesses are supported/enabled.
15 //===----------------------------------------------------------------------===//
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/ADT/SmallPtrSet.h"
19 #include "llvm/Analysis/AliasAnalysis.h"
20 #include "llvm/Analysis/LoopAccessAnalysis.h"
21 #include "llvm/Analysis/LoopPass.h"
22 #include "llvm/Analysis/LoopInfo.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/NoFolder.h"
25 #include "llvm/Transforms/Scalar.h"
26 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
27 #include "llvm/Transforms/Utils/LoopUtils.h"
28 #include "llvm/Pass.h"
29 #include "llvm/PassRegistry.h"
30 #include "llvm/PassSupport.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/IR/PatternMatch.h"
33 #include "llvm/CodeGen/TargetPassConfig.h"
35 #include "ARMSubtarget.h"
38 using namespace PatternMatch
;
40 #define DEBUG_TYPE "arm-parallel-dsp"
42 STATISTIC(NumSMLAD
, "Number of smlad instructions generated");
45 DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden
, cl::init(false),
46 cl::desc("Disable the ARM Parallel DSP pass"));
53 using OpChainList
= SmallVector
<std::unique_ptr
<OpChain
>, 8>;
54 using ReductionList
= SmallVector
<Reduction
, 8>;
55 using ValueList
= SmallVector
<Value
*, 8>;
56 using MemInstList
= SmallVector
<Instruction
*, 8>;
57 using PMACPair
= std::pair
<BinOpChain
*,BinOpChain
*>;
58 using PMACPairList
= SmallVector
<PMACPair
, 8>;
59 using Instructions
= SmallVector
<Instruction
*,16>;
60 using MemLocList
= SmallVector
<MemoryLocation
, 4>;
65 MemInstList VecLd
; // List of all load instructions.
66 MemLocList MemLocs
; // All memory locations read by this tree.
69 OpChain(Instruction
*I
, ValueList
&vl
) : Root(I
), AllValues(vl
) { }
70 virtual ~OpChain() = default;
72 void SetMemoryLocations() {
73 const auto Size
= LocationSize::unknown();
74 for (auto *V
: AllValues
) {
75 if (auto *I
= dyn_cast
<Instruction
>(V
)) {
76 if (I
->mayWriteToMemory())
78 if (auto *Ld
= dyn_cast
<LoadInst
>(V
))
79 MemLocs
.push_back(MemoryLocation(Ld
->getPointerOperand(), Size
));
84 unsigned size() const { return AllValues
.size(); }
87 // 'BinOpChain' and 'Reduction' are just some bookkeeping data structures.
88 // 'Reduction' contains the phi-node and accumulator statement from where we
89 // start pattern matching, and 'BinOpChain' the multiplication
90 // instructions that are candidates for parallel execution.
91 struct BinOpChain
: public OpChain
{
92 ValueList LHS
; // List of all (narrow) left hand operands.
93 ValueList RHS
; // List of all (narrow) right hand operands.
94 bool Exchange
= false;
96 BinOpChain(Instruction
*I
, ValueList
&lhs
, ValueList
&rhs
) :
97 OpChain(I
, lhs
), LHS(lhs
), RHS(rhs
) {
99 AllValues
.push_back(V
);
102 bool AreSymmetrical(BinOpChain
*Other
);
106 PHINode
*Phi
; // The Phi-node from where we start
108 Instruction
*AccIntAdd
; // The accumulating integer add statement,
109 // i.e., the reduction statement.
110 OpChainList MACCandidates
; // The MAC candidates associated with
111 // this reduction statement.
112 PMACPairList PMACPairs
;
113 Reduction (PHINode
*P
, Instruction
*Acc
) : Phi(P
), AccIntAdd(Acc
) { };
116 class ARMParallelDSP
: public LoopPass
{
119 TargetLibraryInfo
*TLI
;
123 const DataLayout
*DL
;
125 std::map
<LoadInst
*, LoadInst
*> LoadPairs
;
126 std::map
<LoadInst
*, SmallVector
<LoadInst
*, 4>> SequentialLoads
;
128 bool RecordSequentialLoads(BasicBlock
*Header
);
129 bool InsertParallelMACs(Reduction
&Reduction
);
130 bool AreSequentialLoads(LoadInst
*Ld0
, LoadInst
*Ld1
, MemInstList
&VecMem
);
131 void CreateParallelMACPairs(Reduction
&R
);
132 Instruction
*CreateSMLADCall(LoadInst
*VecLd0
, LoadInst
*VecLd1
,
133 Instruction
*Acc
, bool Exchange
,
134 Instruction
*InsertAfter
);
136 /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
137 /// Dual performs two signed 16x16-bit multiplications. It adds the
138 /// products to a 32-bit accumulate operand. Optionally, the instruction can
139 /// exchange the halfwords of the second operand before performing the
141 bool MatchSMLAD(Function
&F
);
146 ARMParallelDSP() : LoopPass(ID
) { }
148 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
149 LoopPass::getAnalysisUsage(AU
);
150 AU
.addRequired
<AssumptionCacheTracker
>();
151 AU
.addRequired
<ScalarEvolutionWrapperPass
>();
152 AU
.addRequired
<AAResultsWrapperPass
>();
153 AU
.addRequired
<TargetLibraryInfoWrapperPass
>();
154 AU
.addRequired
<LoopInfoWrapperPass
>();
155 AU
.addRequired
<DominatorTreeWrapperPass
>();
156 AU
.addRequired
<TargetPassConfig
>();
157 AU
.addPreserved
<LoopInfoWrapperPass
>();
158 AU
.setPreservesCFG();
161 bool runOnLoop(Loop
*TheLoop
, LPPassManager
&) override
{
162 if (DisableParallelDSP
)
165 SE
= &getAnalysis
<ScalarEvolutionWrapperPass
>().getSE();
166 AA
= &getAnalysis
<AAResultsWrapperPass
>().getAAResults();
167 TLI
= &getAnalysis
<TargetLibraryInfoWrapperPass
>().getTLI();
168 DT
= &getAnalysis
<DominatorTreeWrapperPass
>().getDomTree();
169 LI
= &getAnalysis
<LoopInfoWrapperPass
>().getLoopInfo();
170 auto &TPC
= getAnalysis
<TargetPassConfig
>();
172 BasicBlock
*Header
= TheLoop
->getHeader();
176 // TODO: We assume the loop header and latch to be the same block.
177 // This is not a fundamental restriction, but lifting this would just
178 // require more work to do the transformation and then patch up the CFG.
179 if (Header
!= TheLoop
->getLoopLatch()) {
180 LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
181 "running pass ARMParallelDSP\n");
185 Function
&F
= *Header
->getParent();
187 DL
= &M
->getDataLayout();
189 auto &TM
= TPC
.getTM
<TargetMachine
>();
190 auto *ST
= &TM
.getSubtarget
<ARMSubtarget
>(F
);
192 if (!ST
->allowsUnalignedMem()) {
193 LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
194 "running pass ARMParallelDSP\n");
199 LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
204 LoopAccessInfo
LAI(L
, SE
, TLI
, AA
, DT
, LI
);
205 bool Changes
= false;
207 LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
208 LLVM_DEBUG(dbgs() << " - " << F
.getName() << "\n\n");
210 if (!RecordSequentialLoads(Header
)) {
211 LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
215 Changes
= MatchSMLAD(F
);
221 // MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
222 // instructions, which is set to 16. So here we should collect all i8 and i16
223 // narrow operations.
224 // TODO: we currently only collect i16, and will support i8 later, so that's
225 // why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
226 template<unsigned MaxBitWidth
>
227 static bool IsNarrowSequence(Value
*V
, ValueList
&VL
) {
228 LLVM_DEBUG(dbgs() << "Is narrow sequence? "; V
->dump());
231 if (match(V
, m_ConstantInt(CInt
))) {
232 // TODO: if a constant is used, it needs to fit within the bit width.
236 auto *I
= dyn_cast
<Instruction
>(V
);
240 Value
*Val
, *LHS
, *RHS
;
241 if (match(V
, m_Trunc(m_Value(Val
)))) {
242 if (cast
<TruncInst
>(I
)->getDestTy()->getIntegerBitWidth() == MaxBitWidth
)
243 return IsNarrowSequence
<MaxBitWidth
>(Val
, VL
);
244 } else if (match(V
, m_Add(m_Value(LHS
), m_Value(RHS
)))) {
245 // TODO: we need to implement sadd16/sadd8 for this, which would also
246 // enable the rewrite for smlad8.ll, but it is unsupported for now.
247 LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I
->dump());
249 } else if (match(V
, m_ZExtOrSExt(m_Value(Val
)))) {
250 if (cast
<CastInst
>(I
)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth
) {
251 LLVM_DEBUG(dbgs() << "No, wrong SrcTy size: " <<
252 cast
<CastInst
>(I
)->getSrcTy()->getIntegerBitWidth() << "\n");
256 if (match(Val
, m_Load(m_Value()))) {
257 LLVM_DEBUG(dbgs() << "Yes, found narrow Load:\t"; Val
->dump());
263 LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I
->dump());
267 template<typename MemInst
>
268 static bool AreSequentialAccesses(MemInst
*MemOp0
, MemInst
*MemOp1
,
269 const DataLayout
&DL
, ScalarEvolution
&SE
) {
270 if (!MemOp0
->isSimple() || !MemOp1
->isSimple()) {
271 LLVM_DEBUG(dbgs() << "No, not touching volatile access\n");
274 if (isConsecutiveAccess(MemOp0
, MemOp1
, DL
, SE
)) {
275 LLVM_DEBUG(dbgs() << "OK: accesses are consecutive.\n");
278 LLVM_DEBUG(dbgs() << "No, accesses aren't consecutive.\n");
282 bool ARMParallelDSP::AreSequentialLoads(LoadInst
*Ld0
, LoadInst
*Ld1
,
283 MemInstList
&VecMem
) {
287 LLVM_DEBUG(dbgs() << "Are consecutive loads:\n";
288 dbgs() << "Ld0:"; Ld0
->dump();
289 dbgs() << "Ld1:"; Ld1
->dump();
292 if (!Ld0
->hasOneUse() || !Ld1
->hasOneUse()) {
293 LLVM_DEBUG(dbgs() << "No, load has more than one use.\n");
297 if (!LoadPairs
.count(Ld0
) || LoadPairs
[Ld0
] != Ld1
)
301 VecMem
.push_back(Ld0
);
302 VecMem
.push_back(Ld1
);
306 /// Iterate through the block and record base, offset pairs of loads as well as
307 /// maximal sequences of sequential loads.
308 bool ARMParallelDSP::RecordSequentialLoads(BasicBlock
*Header
) {
309 SmallVector
<LoadInst
*, 8> Loads
;
310 for (auto &I
: *Header
) {
311 auto *Ld
= dyn_cast
<LoadInst
>(&I
);
317 std::map
<LoadInst
*, LoadInst
*> BaseLoads
;
319 for (auto *Ld0
: Loads
) {
320 for (auto *Ld1
: Loads
) {
324 if (AreSequentialAccesses
<LoadInst
>(Ld0
, Ld1
, *DL
, *SE
)) {
325 LoadPairs
[Ld0
] = Ld1
;
326 if (BaseLoads
.count(Ld0
)) {
327 LoadInst
*Base
= BaseLoads
[Ld0
];
328 BaseLoads
[Ld1
] = Base
;
329 SequentialLoads
[Base
].push_back(Ld1
);
331 BaseLoads
[Ld1
] = Ld0
;
332 SequentialLoads
[Ld0
].push_back(Ld1
);
337 return LoadPairs
.size() > 1;
340 void ARMParallelDSP::CreateParallelMACPairs(Reduction
&R
) {
341 OpChainList
&Candidates
= R
.MACCandidates
;
342 PMACPairList
&PMACPairs
= R
.PMACPairs
;
343 const unsigned Elems
= Candidates
.size();
348 auto CanPair
= [&](BinOpChain
*PMul0
, BinOpChain
*PMul1
) {
349 if (!PMul0
->AreSymmetrical(PMul1
))
352 // The first elements of each vector should be loads with sexts. If we
353 // find that they are two pairs of consecutive loads, then these can be
354 // transformed into two wider loads and the users can be replaced with
356 for (unsigned x
= 0; x
< PMul0
->LHS
.size(); x
+= 2) {
357 auto *Ld0
= dyn_cast
<LoadInst
>(PMul0
->LHS
[x
]);
358 auto *Ld1
= dyn_cast
<LoadInst
>(PMul1
->LHS
[x
]);
359 auto *Ld2
= dyn_cast
<LoadInst
>(PMul0
->RHS
[x
]);
360 auto *Ld3
= dyn_cast
<LoadInst
>(PMul1
->RHS
[x
]);
362 if (!Ld0
|| !Ld1
|| !Ld2
|| !Ld3
)
365 LLVM_DEBUG(dbgs() << "Looking at operands " << x
<< ":\n"
366 << "\t Ld0: " << *Ld0
<< "\n"
367 << "\t Ld1: " << *Ld1
<< "\n"
368 << "and operands " << x
+ 2 << ":\n"
369 << "\t Ld2: " << *Ld2
<< "\n"
370 << "\t Ld3: " << *Ld3
<< "\n");
372 if (AreSequentialLoads(Ld0
, Ld1
, PMul0
->VecLd
)) {
373 if (AreSequentialLoads(Ld2
, Ld3
, PMul1
->VecLd
)) {
374 LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
375 PMACPairs
.push_back(std::make_pair(PMul0
, PMul1
));
377 } else if (AreSequentialLoads(Ld3
, Ld2
, PMul1
->VecLd
)) {
378 LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
379 LLVM_DEBUG(dbgs() << " exchanging Ld2 and Ld3\n");
380 PMul1
->Exchange
= true;
381 PMACPairs
.push_back(std::make_pair(PMul0
, PMul1
));
384 } else if (AreSequentialLoads(Ld1
, Ld0
, PMul0
->VecLd
) &&
385 AreSequentialLoads(Ld2
, Ld3
, PMul1
->VecLd
)) {
386 LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
387 LLVM_DEBUG(dbgs() << " exchanging Ld0 and Ld1\n");
388 LLVM_DEBUG(dbgs() << " and swapping muls\n");
389 PMul0
->Exchange
= true;
390 // Only the second operand can be exchanged, so swap the muls.
391 PMACPairs
.push_back(std::make_pair(PMul1
, PMul0
));
398 SmallPtrSet
<const Instruction
*, 4> Paired
;
399 for (unsigned i
= 0; i
< Elems
; ++i
) {
400 BinOpChain
*PMul0
= static_cast<BinOpChain
*>(Candidates
[i
].get());
401 if (Paired
.count(PMul0
->Root
))
404 for (unsigned j
= 0; j
< Elems
; ++j
) {
408 BinOpChain
*PMul1
= static_cast<BinOpChain
*>(Candidates
[j
].get());
409 if (Paired
.count(PMul1
->Root
))
412 const Instruction
*Mul0
= PMul0
->Root
;
413 const Instruction
*Mul1
= PMul1
->Root
;
417 assert(PMul0
!= PMul1
&& "expected different chains");
419 LLVM_DEBUG(dbgs() << "\nCheck parallel muls:\n";
420 dbgs() << "- "; Mul0
->dump();
421 dbgs() << "- "; Mul1
->dump());
423 LLVM_DEBUG(dbgs() << "OK: mul operands list match:\n");
424 if (CanPair(PMul0
, PMul1
)) {
433 bool ARMParallelDSP::InsertParallelMACs(Reduction
&Reduction
) {
434 Instruction
*Acc
= Reduction
.Phi
;
435 Instruction
*InsertAfter
= Reduction
.AccIntAdd
;
437 for (auto &Pair
: Reduction
.PMACPairs
) {
438 BinOpChain
*PMul0
= Pair
.first
;
439 BinOpChain
*PMul1
= Pair
.second
;
440 LLVM_DEBUG(dbgs() << "Found parallel MACs!!\n";
441 dbgs() << "- "; PMul0
->Root
->dump();
442 dbgs() << "- "; PMul1
->Root
->dump());
444 auto *VecLd0
= cast
<LoadInst
>(PMul0
->VecLd
[0]);
445 auto *VecLd1
= cast
<LoadInst
>(PMul1
->VecLd
[0]);
446 Acc
= CreateSMLADCall(VecLd0
, VecLd1
, Acc
, PMul1
->Exchange
, InsertAfter
);
450 if (Acc
!= Reduction
.Phi
) {
451 LLVM_DEBUG(dbgs() << "Replace Accumulate: "; Acc
->dump());
452 Reduction
.AccIntAdd
->replaceAllUsesWith(Acc
);
458 static void MatchReductions(Function
&F
, Loop
*TheLoop
, BasicBlock
*Header
,
459 ReductionList
&Reductions
) {
460 RecurrenceDescriptor RecDesc
;
461 const bool HasFnNoNaNAttr
=
462 F
.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
463 const BasicBlock
*Latch
= TheLoop
->getLoopLatch();
465 // We need a preheader as getIncomingValueForBlock assumes there is one.
466 if (!TheLoop
->getLoopPreheader()) {
467 LLVM_DEBUG(dbgs() << "No preheader found, bailing out\n");
471 for (PHINode
&Phi
: Header
->phis()) {
472 const auto *Ty
= Phi
.getType();
473 if (!Ty
->isIntegerTy(32) && !Ty
->isIntegerTy(64))
476 const bool IsReduction
=
477 RecurrenceDescriptor::AddReductionVar(&Phi
,
478 RecurrenceDescriptor::RK_IntegerAdd
,
479 TheLoop
, HasFnNoNaNAttr
, RecDesc
);
483 Instruction
*Acc
= dyn_cast
<Instruction
>(Phi
.getIncomingValueForBlock(Latch
));
487 Reductions
.push_back(Reduction(&Phi
, Acc
));
491 dbgs() << "\nAccumulating integer additions (reductions) found:\n";
492 for (auto &R
: Reductions
) {
493 dbgs() << "- "; R
.Phi
->dump();
494 dbgs() << "-> "; R
.AccIntAdd
->dump();
499 static void AddMACCandidate(OpChainList
&Candidates
,
501 Value
*MulOp0
, Value
*MulOp1
) {
502 LLVM_DEBUG(dbgs() << "OK, found acc mul:\t"; Mul
->dump());
503 assert(Mul
->getOpcode() == Instruction::Mul
&&
504 "expected mul instruction");
507 if (IsNarrowSequence
<16>(MulOp0
, LHS
) &&
508 IsNarrowSequence
<16>(MulOp1
, RHS
)) {
509 LLVM_DEBUG(dbgs() << "OK, found narrow mul: "; Mul
->dump());
510 Candidates
.push_back(make_unique
<BinOpChain
>(Mul
, LHS
, RHS
));
514 static void MatchParallelMACSequences(Reduction
&R
,
515 OpChainList
&Candidates
) {
516 Instruction
*Acc
= R
.AccIntAdd
;
517 LLVM_DEBUG(dbgs() << "\n- Analysing:\t" << *Acc
);
519 // Returns false to signal the search should be stopped.
520 std::function
<bool(Value
*)> Match
=
521 [&Candidates
, &Match
](Value
*V
) -> bool {
523 auto *I
= dyn_cast
<Instruction
>(V
);
527 switch (I
->getOpcode()) {
528 case Instruction::Add
:
529 if (Match(I
->getOperand(0)) || (Match(I
->getOperand(1))))
532 case Instruction::Mul
: {
533 Value
*MulOp0
= I
->getOperand(0);
534 Value
*MulOp1
= I
->getOperand(1);
535 if (isa
<SExtInst
>(MulOp0
) && isa
<SExtInst
>(MulOp1
))
536 AddMACCandidate(Candidates
, I
, MulOp0
, MulOp1
);
539 case Instruction::SExt
:
540 return Match(I
->getOperand(0));
546 LLVM_DEBUG(dbgs() << "Finished matching MAC sequences, found "
547 << Candidates
.size() << " candidates.\n");
550 // Collects all instructions that are not part of the MAC chains, which is the
551 // set of instructions that can potentially alias with the MAC operands.
552 static void AliasCandidates(BasicBlock
*Header
, Instructions
&Reads
,
553 Instructions
&Writes
) {
554 for (auto &I
: *Header
) {
555 if (I
.mayReadFromMemory())
557 if (I
.mayWriteToMemory())
558 Writes
.push_back(&I
);
562 // Check whether statements in the basic block that write to memory alias with
563 // the memory locations accessed by the MAC-chains.
564 // TODO: we need the read statements when we accept more complicated chains.
565 static bool AreAliased(AliasAnalysis
*AA
, Instructions
&Reads
,
566 Instructions
&Writes
, OpChainList
&MACCandidates
) {
567 LLVM_DEBUG(dbgs() << "Alias checks:\n");
568 for (auto &MAC
: MACCandidates
) {
569 LLVM_DEBUG(dbgs() << "mul: "; MAC
->Root
->dump());
571 // At the moment, we allow only simple chains that only consist of reads,
572 // accumulate their result with an integer add, and thus that don't write
573 // memory, and simply bail if they do.
577 // Now for all writes in the basic block, check that they don't alias with
578 // the memory locations accessed by our MAC-chain:
579 for (auto *I
: Writes
) {
580 LLVM_DEBUG(dbgs() << "- "; I
->dump());
581 assert(MAC
->MemLocs
.size() >= 2 && "expecting at least 2 memlocs");
582 for (auto &MemLoc
: MAC
->MemLocs
) {
583 if (isModOrRefSet(intersectModRef(AA
->getModRefInfo(I
, MemLoc
),
584 ModRefInfo::ModRef
))) {
585 LLVM_DEBUG(dbgs() << "Yes, aliases found\n");
592 LLVM_DEBUG(dbgs() << "OK: no aliases found!\n");
596 static bool CheckMACMemory(OpChainList
&Candidates
) {
597 for (auto &C
: Candidates
) {
598 // A mul has 2 operands, and a narrow op consist of sext and a load; thus
599 // we expect at least 4 items in this operand value list.
601 LLVM_DEBUG(dbgs() << "Operand list too short.\n");
604 C
->SetMemoryLocations();
605 ValueList
&LHS
= static_cast<BinOpChain
*>(C
.get())->LHS
;
606 ValueList
&RHS
= static_cast<BinOpChain
*>(C
.get())->RHS
;
608 // Use +=2 to skip over the expected extend instructions.
609 for (unsigned i
= 0, e
= LHS
.size(); i
< e
; i
+= 2) {
610 if (!isa
<LoadInst
>(LHS
[i
]) || !isa
<LoadInst
>(RHS
[i
]))
617 // Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
620 // 1) we first need to find integer add reduction PHIs,
621 // 2) then from the PHI, look for this pattern:
623 // acc0 = phi i32 [0, %entry], [%acc1, %loop.body]
625 // sext0 = sext i16 %ld0 to i32
627 // sext1 = sext i16 %ld1 to i32
628 // mul0 = mul %sext0, %sext1
630 // sext2 = sext i16 %ld2 to i32
632 // sext3 = sext i16 %ld3 to i32
633 // mul1 = mul i32 %sext2, %sext3
634 // add0 = add i32 %mul0, %acc0
635 // acc1 = add i32 %add0, %mul1
637 // Which can be selected to:
641 // smlad r2, r0, r1, r2
643 // If constants are used instead of loads, these will need to be hoisted
644 // out and into a register.
646 // If loop invariants are used instead of loads, these need to be packed
647 // before the loop begins.
649 bool ARMParallelDSP::MatchSMLAD(Function
&F
) {
650 BasicBlock
*Header
= L
->getHeader();
651 LLVM_DEBUG(dbgs() << "= Matching SMLAD =\n";
652 dbgs() << "Header block:\n"; Header
->dump();
653 dbgs() << "Loop info:\n\n"; L
->dump());
655 bool Changed
= false;
656 ReductionList Reductions
;
657 MatchReductions(F
, L
, Header
, Reductions
);
659 for (auto &R
: Reductions
) {
660 OpChainList MACCandidates
;
661 MatchParallelMACSequences(R
, MACCandidates
);
662 if (!CheckMACMemory(MACCandidates
))
665 R
.MACCandidates
= std::move(MACCandidates
);
667 LLVM_DEBUG(dbgs() << "MAC candidates:\n";
668 for (auto &M
: R
.MACCandidates
)
673 // Collect all instructions that may read or write memory. Our alias
674 // analysis checks bail out if any of these instructions aliases with an
675 // instruction from the MAC-chain.
676 Instructions Reads
, Writes
;
677 AliasCandidates(Header
, Reads
, Writes
);
679 for (auto &R
: Reductions
) {
680 if (AreAliased(AA
, Reads
, Writes
, R
.MACCandidates
))
682 CreateParallelMACPairs(R
);
683 Changed
|= InsertParallelMACs(R
);
686 LLVM_DEBUG(if (Changed
) dbgs() << "Header block:\n"; Header
->dump(););
690 static LoadInst
*CreateLoadIns(IRBuilder
<NoFolder
> &IRB
, LoadInst
&BaseLoad
,
692 const unsigned AddrSpace
= BaseLoad
.getPointerAddressSpace();
694 Value
*VecPtr
= IRB
.CreateBitCast(BaseLoad
.getPointerOperand(),
695 LoadTy
->getPointerTo(AddrSpace
));
696 return IRB
.CreateAlignedLoad(LoadTy
, VecPtr
, BaseLoad
.getAlignment());
699 Instruction
*ARMParallelDSP::CreateSMLADCall(LoadInst
*VecLd0
, LoadInst
*VecLd1
,
700 Instruction
*Acc
, bool Exchange
,
701 Instruction
*InsertAfter
) {
702 LLVM_DEBUG(dbgs() << "Create SMLAD intrinsic using:\n"
703 << "- " << *VecLd0
<< "\n"
704 << "- " << *VecLd1
<< "\n"
705 << "- " << *Acc
<< "\n"
706 << "Exchange: " << Exchange
<< "\n");
708 IRBuilder
<NoFolder
> Builder(InsertAfter
->getParent(),
709 ++BasicBlock::iterator(InsertAfter
));
711 // Replace the reduction chain with an intrinsic call
712 Type
*Ty
= IntegerType::get(M
->getContext(), 32);
713 LoadInst
*NewLd0
= CreateLoadIns(Builder
, VecLd0
[0], Ty
);
714 LoadInst
*NewLd1
= CreateLoadIns(Builder
, VecLd1
[0], Ty
);
715 Value
* Args
[] = { NewLd0
, NewLd1
, Acc
};
716 Function
*SMLAD
= nullptr;
718 SMLAD
= Acc
->getType()->isIntegerTy(32) ?
719 Intrinsic::getDeclaration(M
, Intrinsic::arm_smladx
) :
720 Intrinsic::getDeclaration(M
, Intrinsic::arm_smlaldx
);
722 SMLAD
= Acc
->getType()->isIntegerTy(32) ?
723 Intrinsic::getDeclaration(M
, Intrinsic::arm_smlad
) :
724 Intrinsic::getDeclaration(M
, Intrinsic::arm_smlald
);
725 CallInst
*Call
= Builder
.CreateCall(SMLAD
, Args
);
730 // Compare the value lists in Other to this chain.
731 bool BinOpChain::AreSymmetrical(BinOpChain
*Other
) {
732 // Element-by-element comparison of Value lists returning true if they are
733 // instructions with the same opcode or constants with the same value.
734 auto CompareValueList
= [](const ValueList
&VL0
,
735 const ValueList
&VL1
) {
736 if (VL0
.size() != VL1
.size()) {
737 LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
738 << VL0
.size() << " != " << VL1
.size() << "\n");
742 const unsigned Pairs
= VL0
.size();
743 LLVM_DEBUG(dbgs() << "Number of operand pairs: " << Pairs
<< "\n");
745 for (unsigned i
= 0; i
< Pairs
; ++i
) {
746 const Value
*V0
= VL0
[i
];
747 const Value
*V1
= VL1
[i
];
748 const auto *Inst0
= dyn_cast
<Instruction
>(V0
);
749 const auto *Inst1
= dyn_cast
<Instruction
>(V1
);
751 LLVM_DEBUG(dbgs() << "Pair " << i
<< ":\n";
752 dbgs() << "mul1: "; V0
->dump();
753 dbgs() << "mul2: "; V1
->dump());
755 if (!Inst0
|| !Inst1
)
758 if (Inst0
->isSameOperationAs(Inst1
)) {
759 LLVM_DEBUG(dbgs() << "OK: same operation found!\n");
763 const APInt
*C0
, *C1
;
764 if (!(match(V0
, m_APInt(C0
)) && match(V1
, m_APInt(C1
)) && C0
== C1
))
768 LLVM_DEBUG(dbgs() << "OK: found symmetrical operand lists.\n");
772 return CompareValueList(LHS
, Other
->LHS
) &&
773 CompareValueList(RHS
, Other
->RHS
);
776 Pass
*llvm::createARMParallelDSPPass() {
777 return new ARMParallelDSP();
780 char ARMParallelDSP::ID
= 0;
782 INITIALIZE_PASS_BEGIN(ARMParallelDSP
, "arm-parallel-dsp",
783 "Transform loops to use DSP intrinsics", false, false)
784 INITIALIZE_PASS_END(ARMParallelDSP
, "arm-parallel-dsp",
785 "Transform loops to use DSP intrinsics", false, false)