//===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
//
//===----------------------------------------------------------------------===//
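
// As a motivating sketch (illustrative, not lifted from the tests), a C loop
// such as:
//
//   int dot(short *a, short *b, int n) {
//     int acc = 0;
//     for (int i = 0; i < n; ++i)
//       acc += a[i] * b[i];
//     return acc;
//   }
//
// contains pairs of sign-extending i16 loads and multiplies that this pass
// can rewrite into calls to @llvm.arm.smlad, which select to the SMLAD
// instruction: two 16x16->32 multiplications plus a 32-bit accumulate in a
// single instruction.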

#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/PassSupport.h"
#include "llvm/Support/Debug.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/CodeGen/TargetPassConfig.h"

#include "ARMSubtarget.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

using OpChainList     = SmallVector<std::unique_ptr<OpChain>, 8>;
using ReductionList   = SmallVector<Reduction, 8>;
using ValueList       = SmallVector<Value*, 8>;
using MemInstList     = SmallVector<LoadInst*, 8>;
using PMACPair        = std::pair<BinOpChain*, BinOpChain*>;
using PMACPairList    = SmallVector<PMACPair, 8>;
using Instructions    = SmallVector<Instruction*, 16>;
using MemLocList      = SmallVector<MemoryLocation, 4>;

struct OpChain {
  Instruction *Root;
  ValueList AllValues;
  MemInstList VecLd;    // List of all load instructions.

  OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
  virtual ~OpChain() = default;

  void PopulateLoads() {
    for (auto *V : AllValues) {
      if (auto *Ld = dyn_cast<LoadInst>(V))
        VecLd.push_back(Ld);
    }
  }

  unsigned size() const { return AllValues.size(); }
};

// 'BinOpChain' holds the multiplication instructions that are candidates
// for parallel execution.
struct BinOpChain : public OpChain {
  ValueList LHS;            // List of all (narrow) left hand operands.
  ValueList RHS;            // List of all (narrow) right hand operands.
  bool Exchange = false;

  BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
    OpChain(I, lhs), LHS(lhs), RHS(rhs) {
      for (auto *V : RHS)
        AllValues.push_back(V);
    }

  bool AreSymmetrical(BinOpChain *Other);
};

/// Represent a sequence of multiply-accumulate operations with the aim to
/// perform the multiplications in parallel.
class Reduction {
  Instruction     *Root = nullptr;
  Value           *Acc = nullptr;
  OpChainList     Muls;
  PMACPairList    MulPairs;
  SmallPtrSet<Instruction*, 4> Adds;

public:
  Reduction() = delete;

  Reduction(Instruction *Add) : Root(Add) { }

  /// Record an Add instruction that is a part of this reduction.
  void InsertAdd(Instruction *I) { Adds.insert(I); }

  /// Record a BinOpChain, rooted at a Mul instruction, that is a part of
  /// this reduction.
  void InsertMul(Instruction *I, ValueList &LHS, ValueList &RHS) {
    Muls.push_back(make_unique<BinOpChain>(I, LHS, RHS));
  }

  /// Add the incoming accumulator value; returns true if a value had not
  /// already been added. Returning false signals to the user that this
  /// reduction already has a value to initialise the accumulator.
  bool InsertAcc(Value *V) {
    if (Acc)
      return false;
    Acc = V;
    return true;
  }

  /// Set two BinOpChains, rooted at muls, that can be executed as a single
  /// parallel operation.
  void AddMulPair(BinOpChain *Mul0, BinOpChain *Mul1) {
    MulPairs.push_back(std::make_pair(Mul0, Mul1));
  }

  /// Return true if enough mul operations are found that can be executed in
  /// parallel.
  bool CreateParallelPairs();

  /// Return the add instruction which is the root of the reduction.
  Instruction *getRoot() { return Root; }

  /// Return the incoming value to be accumulated. This may be null.
  Value *getAccumulator() { return Acc; }

  /// Return the set of adds that comprise the reduction.
  SmallPtrSetImpl<Instruction*> &getAdds() { return Adds; }

  /// Return the BinOpChains, rooted at mul instructions, that comprise the
  /// reduction.
  OpChainList &getMuls() { return Muls; }

  /// Return the BinOpChains, rooted at mul instructions, that have been
  /// paired for parallel execution.
  PMACPairList &getMulPairs() { return MulPairs; }

  /// To finalise, replace the uses of the root with the intrinsic call.
  void UpdateRoot(Instruction *SMLAD) {
    Root->replaceAllUsesWith(SMLAD);
  }
};

// A wrapper around the widened load that was created from two narrow,
// consecutive loads.
class WidenedLoad {
  LoadInst *NewLd = nullptr;
  SmallVector<LoadInst*, 4> Loads;

public:
  WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
    : NewLd(Wide) {
    for (auto *I : Lds)
      Loads.push_back(I);
  }
  LoadInst *getLoad() {
    return NewLd;
  }
};

class ARMParallelDSP : public LoopPass {
  ScalarEvolution   *SE;
  AliasAnalysis     *AA;
  TargetLibraryInfo *TLI;
  DominatorTree     *DT;
  LoopInfo          *LI;
  Loop              *L;
  const DataLayout  *DL;
  Module            *M;
  std::map<LoadInst*, LoadInst*> LoadPairs;
  SmallPtrSet<LoadInst*, 4> OffsetLoads;
  std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

  template<unsigned MaxBitWidth>
  bool IsNarrowSequence(Value *V, ValueList &VL);

  bool RecordMemoryOps(BasicBlock *BB);
  void InsertParallelMACs(Reduction &Reduction);
  bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
  LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
                           IntegerType *LoadTy);
  bool CreateParallelPairs(Reduction &R);

  /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
  /// Dual performs two signed 16x16-bit multiplications. It adds the
  /// products to a 32-bit accumulate operand. Optionally, the instruction can
  /// exchange the halfwords of the second operand before performing the
  /// arithmetic.
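  ///
  /// For reference, a sketch of the underlying instruction semantics, with
  /// Ra as the accumulator:
  ///   SMLAD  Rd,Rn,Rm,Ra : Rd = Ra + Rn[15:0]*Rm[15:0] + Rn[31:16]*Rm[31:16]
  ///   SMLADX Rd,Rn,Rm,Ra : Rd = Ra + Rn[15:0]*Rm[31:16] + Rn[31:16]*Rm[15:0]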
  bool MatchSMLAD(Loop *L);

public:
  static char ID;

  ARMParallelDSP() : LoopPass(ID) { }

  bool doInitialization(Loop *L, LPPassManager &LPM) override {
    LoadPairs.clear();
    WideLoads.clear();
    return true;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    LoopPass::getAnalysisUsage(AU);
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetPassConfig>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.setPreservesCFG();
  }

  bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
    if (DisableParallelDSP)
      return false;
    if (skipLoop(TheLoop))
      return false;

    L = TheLoop;
    SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
    AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
    TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
    DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
    auto &TPC = getAnalysis<TargetPassConfig>();

    BasicBlock *Header = TheLoop->getHeader();
    if (!Header)
      return false;

    // TODO: We assume the loop header and latch to be the same block.
    // This is not a fundamental restriction, but lifting this would just
    // require more work to do the transformation and then patch up the CFG.
    if (Header != TheLoop->getLoopLatch()) {
      LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
                           "running pass ARMParallelDSP\n");
      return false;
    }

    if (!TheLoop->getLoopPreheader())
      InsertPreheaderForLoop(L, DT, LI, nullptr, true);

    Function &F = *Header->getParent();
    M = F.getParent();
    DL = &M->getDataLayout();

    auto &TM = TPC.getTM<TargetMachine>();
    auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

    if (!ST->allowsUnalignedMem()) {
      LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                           "running pass ARMParallelDSP\n");
      return false;
    }

    if (!ST->hasDSP()) {
      LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                        << "ARMParallelDSP\n");
      return false;
    }

    if (!ST->isLittle()) {
      LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                        << "ARMParallelDSP\n");
      return false;
    }

    LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);

    LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
    LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

    if (!RecordMemoryOps(Header)) {
      LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
      return false;
    }

    bool Changes = MatchSMLAD(L);
    return Changes;
  }
};

template<typename MemInst>
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                  const DataLayout &DL, ScalarEvolution &SE) {
  if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE))
    return true;
  return false;
}
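
// For example, i16 loads of a[i] and a[i+1] are sequential accesses: their
// pointers differ by exactly the 2-byte element size, which
// isConsecutiveAccess (from LoopAccessAnalysis) can prove using SCEV.
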
bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
             dbgs() << "Ld0:"; Ld0->dump();
             dbgs() << "Ld1:"; Ld1->dump();
            );

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}

// MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
// instructions, which is set to 16. So here we should collect all i8 and i16
// narrow operations.
// TODO: we currently only collect i16, and will support i8 later, so that's
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
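//
// For example (illustrative IR), this is a narrow sequence rooted at the
// sext, because the load feeds a single sign-extend up to the full width:
//
//   %ld   = load i16, i16* %addr
//   %sext = sext i16 %ld to i32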
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V, ValueList &VL) {
  ConstantInt *CInt;

  if (match(V, m_ConstantInt(CInt))) {
    // TODO: if a constant is used, it needs to fit within the bit width.
    return false;
  }

  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return false;

  Value *Val, *LHS, *RHS;
  if (match(V, m_Trunc(m_Value(Val)))) {
    if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
      return IsNarrowSequence<MaxBitWidth>(Val, VL);
  } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
    // TODO: we need to implement sadd16/sadd8 for this, which would enable
    // us to also do the rewrite for smlad8.ll, but it is unsupported for now.
    return false;
  } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
    if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    if (match(Val, m_Load(m_Value()))) {
      auto *Ld = cast<LoadInst>(Val);

      // Check that these loads could be paired.
      if (!LoadPairs.count(Ld) && !OffsetLoads.count(Ld))
        return false;

      VL.push_back(Val);
      VL.push_back(I);
      return true;
    }
  }
  return false;
}

/// Iterate through the block and record base, offset pairs of loads which can
/// be widened into a single load.
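///
/// For example (illustrative IR), the following two loads are recorded as a
/// (base, offset) pair because they access adjacent i16 elements:
///
///   %ld0 = load i16, i16* %addr0     ; base
///   %ld1 = load i16, i16* %addr1     ; offset: %addr1 == %addr0 + 1 (i16*)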
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;

  // Collect loads and instructions that may write to memory. For now we only
  // record loads which are simple, sign-extended and have a single user.
  // TODO: Allow zero-extended loads.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }

  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;

  // Record any writes that may alias a load.
  const auto Size = LocationSize::unknown();
  for (auto Read : Loads) {
    for (auto Write : Writes) {
      MemoryLocation ReadLoc =
        MemoryLocation(Read->getPointerOperand(), Size);

      if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
          ModRefInfo::ModRef)))
        continue;
      if (DT->dominates(Write, Read))
        RAWDeps[Read].insert(Write);
    }
  }

  // Check that there isn't a write between the two loads which would prevent
  // them from being safely merged.
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
    LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto Before : WritesBefore) {
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (DT->dominates(Dominator, Before))
          return false;
      }
    }
    return true;
  };

  // Record base, offset load pairs.
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset)
        continue;

      if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
    }
  }

  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs)
                 dbgs() << *MapIt.first << ", " << *MapIt.second << "\n";
             });
  return LoadPairs.size() > 1;
}

// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
// multiplications.
// To use SMLAD:
// 1) we first need to find integer add then look for this pattern:
//
//      acc0 = ...
//      ld0 = load i16
//      sext0 = sext i16 %ld0 to i32
//      ld1 = load i16
//      sext1 = sext i16 %ld1 to i32
//      mul0 = mul %sext0, %sext1
//      ld2 = load i16
//      sext2 = sext i16 %ld2 to i32
//      ld3 = load i16
//      sext3 = sext i16 %ld3 to i32
//      mul1 = mul i32 %sext2, %sext3
//      add0 = add i32 %mul0, %acc0
//      acc1 = add i32 %add0, %mul1
//
// Which can be selected to:
//
//      ldr r0
//      ldr r1
//      smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
bool ARMParallelDSP::MatchSMLAD(Loop *L) {
  // Search recursively back through the operands to find a tree of values that
  // form a multiply-accumulate chain. The search records the Add and Mul
  // instructions that form the reduction and allows us to find a single value
  // to be used as the initial input to the accumulator.
  std::function<bool(Value*, Reduction&)> Search = [&]
    (Value *V, Reduction &R) -> bool {

    // If we find a non-instruction, try to use it as the initial accumulator
    // value. This may have already been found during the search, in which case
    // this function will return false, signaling a search fail.
    auto *I = dyn_cast<Instruction>(V);
    if (!I)
      return R.InsertAcc(V);

    switch (I->getOpcode()) {
    default:
      break;
    case Instruction::PHI:
      // Could be the accumulator value.
      return R.InsertAcc(V);
    case Instruction::Add: {
      // Adds should be adding together two muls, or another add and a mul to
      // be within the mac chain. One of the operands may also be the
      // accumulator value at which point we should stop searching.
      bool ValidLHS = Search(I->getOperand(0), R);
      bool ValidRHS = Search(I->getOperand(1), R);
      if (!ValidLHS && !ValidRHS)
        return false;
      else if (ValidLHS && ValidRHS) {
        R.InsertAdd(I);
        return true;
      } else {
        R.InsertAdd(I);
        return R.InsertAcc(I);
      }
    }
    case Instruction::Mul: {
      Value *MulOp0 = I->getOperand(0);
      Value *MulOp1 = I->getOperand(1);
      if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
        ValueList LHS;
        ValueList RHS;
        if (IsNarrowSequence<16>(MulOp0, LHS) &&
            IsNarrowSequence<16>(MulOp1, RHS)) {
          R.InsertMul(I, LHS, RHS);
          return true;
        }
      }
      return false;
    }
    case Instruction::SExt:
      return Search(I->getOperand(0), R);
    }
    return false;
  };

  bool Changed = false;
  SmallPtrSet<Instruction*, 4> AllAdds;
  BasicBlock *Latch = L->getLoopLatch();

  for (Instruction &I : reverse(*Latch)) {
    if (I.getOpcode() != Instruction::Add)
      continue;

    if (AllAdds.count(&I))
      continue;

    const auto *Ty = I.getType();
    if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
      continue;

    Reduction R(&I);
    if (!Search(&I, R))
      continue;

    if (!CreateParallelPairs(R))
      continue;

    InsertParallelMACs(R);
    Changed = true;
    AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
  }

  return Changed;
}

bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {

  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly upon sign extended loads.
  for (auto &MulChain : R.getMuls()) {
    // A mul has 2 operands, and a narrow op consists of a sext and a load;
    // thus we expect at least 4 items in this operand value list.
    if (MulChain->size() < 4) {
      LLVM_DEBUG(dbgs() << "Operand list too short.\n");
      return false;
    }
    MulChain->PopulateLoads();
    ValueList &LHS = static_cast<BinOpChain*>(MulChain.get())->LHS;
    ValueList &RHS = static_cast<BinOpChain*>(MulChain.get())->RHS;

    // Use +=2 to skip over the expected extend instructions.
    for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
      if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
        return false;
    }
  }

  auto CanPair = [&](Reduction &R, BinOpChain *PMul0, BinOpChain *PMul1) {
    if (!PMul0->AreSymmetrical(PMul1))
      return false;

    // The first elements of each vector should be loads with sexts. If we
    // find that they are two pairs of consecutive loads, then these can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
      auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
      auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
      auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
      auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);

      if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
        return false;

      LLVM_DEBUG(dbgs() << "Loads:\n"
                 << " - " << *Ld0 << "\n"
                 << " - " << *Ld1 << "\n"
                 << " - " << *Ld2 << "\n"
                 << " - " << *Ld3 << "\n");

      if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
        if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
          LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
          R.AddMulPair(PMul0, PMul1);
          return true;
        } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
          LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
          LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
          PMul1->Exchange = true;
          R.AddMulPair(PMul0, PMul1);
          return true;
        }
      } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
                 AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
        LLVM_DEBUG(dbgs() << "    and swapping muls\n");
        PMul0->Exchange = true;
        // Only the second operand can be exchanged, so swap the muls.
        R.AddMulPair(PMul1, PMul0);
        return true;
      }
    }
    return false;
  };

  OpChainList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  SmallPtrSet<const Instruction*, 4> Paired;
  for (unsigned i = 0; i < Elems; ++i) {
    BinOpChain *PMul0 = static_cast<BinOpChain*>(Muls[i].get());
    if (Paired.count(PMul0->Root))
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      BinOpChain *PMul1 = static_cast<BinOpChain*>(Muls[j].get());
      if (Paired.count(PMul1->Root))
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1)) {
        Paired.insert(Mul0);
        Paired.insert(Mul1);
        break;
      }
    }
  }
  return !R.getMulPairs().empty();
}

void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

  auto CreateSMLADCall = [&](SmallVectorImpl<LoadInst*> &VecLd0,
                             SmallVectorImpl<LoadInst*> &VecLd1,
                             Value *Acc, bool Exchange,
                             Instruction *InsertAfter) {
    // Replace the reduction chain with an intrinsic call.
    IntegerType *Ty = IntegerType::get(M->getContext(), 32);
    LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
      WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty);
    LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
      WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty);

    Value *Args[] = { WideLd0, WideLd1, Acc };
    Function *SMLAD = nullptr;
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                ++BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };

  Instruction *InsertAfter = R.getRoot();
  Value *Acc = R.getAccumulator();
  if (!Acc)
    Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);

  LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n"
                    << "Acc: " << *Acc << "\n");
  for (auto &Pair : R.getMulPairs()) {
    BinOpChain *PMul0 = Pair.first;
    BinOpChain *PMul1 = Pair.second;
    LLVM_DEBUG(dbgs() << "Muls:\n"
               << "- " << *PMul0->Root << "\n"
               << "- " << *PMul1->Root << "\n");

    Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
                          InsertAfter);
    InsertAfter = cast<Instruction>(Acc);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}
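
// After InsertParallelMACs, each paired multiply chain has been folded into
// a single intrinsic call that threads through the accumulator, e.g.
// (illustrative IR):
//
//   %acc1 = call i32 @llvm.arm.smlad(i32 %wide0, i32 %wide1, i32 %acc0)
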
LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Bitcast the pointer to a wider type and create the wide load, while making
  // sure to maintain the original alignment as this prevents ldrd from being
  // generated when it could be illegal due to memory alignment.
  const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
                                             Base->getAlignment());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
  // TODO: Support big-endian as well.
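  //
  // For example (illustrative IR, two i16 loads widened to one i32 load):
  //   %wide = load i32, i32* %cast, align 2
  //   %bot  = trunc i32 %wide to i16    ; equals the original Loads[0]
  //   %tmp  = lshr i32 %wide, 16
  //   %top  = trunc i32 %tmp to i16     ; equals the original Loads[1]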
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  BaseSExt->setOperand(0, Bottom);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  OffsetSExt->setOperand(0, Trunc);

  WideLoads.emplace(std::make_pair(Base,
                                   make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}

// Compare the value lists in Other to this chain.
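// For example, the chains for sext(a[i]) * sext(b[i]) and
// sext(a[i+1]) * sext(b[i+1]) are symmetrical: both operand lists have the
// same length and matching operations (or equal constants) at each position.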
bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
  // Element-by-element comparison of Value lists returning true if they are
  // instructions with the same opcode or constants with the same value.
  auto CompareValueList = [](const ValueList &VL0,
                             const ValueList &VL1) {
    if (VL0.size() != VL1.size()) {
      LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
                        << VL0.size() << " != " << VL1.size() << "\n");
      return false;
    }

    const unsigned Pairs = VL0.size();

    for (unsigned i = 0; i < Pairs; ++i) {
      const Value *V0 = VL0[i];
      const Value *V1 = VL1[i];
      const auto *Inst0 = dyn_cast<Instruction>(V0);
      const auto *Inst1 = dyn_cast<Instruction>(V1);

      if (!Inst0 || !Inst1)
        return false;

      if (Inst0->isSameOperationAs(Inst1))
        continue;

      const APInt *C0, *C1;
      if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) && C0 == C1))
        return false;
    }

    return true;
  };

  return CompareValueList(LHS, Other->LHS) &&
         CompareValueList(RHS, Other->RHS);
}

Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform loops to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform loops to use DSP intrinsics", false, false)