//===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a Loop Data Prefetching Pass.
//
//===----------------------------------------------------------------------===//
13 #include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
14 #include "llvm/InitializePasses.h"
16 #include "llvm/ADT/DepthFirstIterator.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/AssumptionCache.h"
19 #include "llvm/Analysis/CodeMetrics.h"
20 #include "llvm/Analysis/LoopInfo.h"
21 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22 #include "llvm/Analysis/ScalarEvolution.h"
23 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
24 #include "llvm/Analysis/TargetTransformInfo.h"
25 #include "llvm/IR/CFG.h"
26 #include "llvm/IR/Dominators.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Transforms/Scalar.h"
32 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
33 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
34 #include "llvm/Transforms/Utils/ValueMapper.h"
36 #define DEBUG_TYPE "loop-data-prefetch"
40 // By default, we limit this to creating 16 PHIs (which is a little over half
41 // of the allocatable register set).
43 PrefetchWrites("loop-prefetch-writes", cl::Hidden
, cl::init(false),
44 cl::desc("Prefetch write addresses"));
46 static cl::opt
<unsigned>
47 PrefetchDistance("prefetch-distance",
48 cl::desc("Number of instructions to prefetch ahead"),
51 static cl::opt
<unsigned>
52 MinPrefetchStride("min-prefetch-stride",
53 cl::desc("Min stride to add prefetches"), cl::Hidden
);
55 static cl::opt
<unsigned> MaxPrefetchIterationsAhead(
56 "max-prefetch-iters-ahead",
57 cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden
);
59 STATISTIC(NumPrefetches
, "Number of prefetches inserted");
63 /// Loop prefetch implementation class.
64 class LoopDataPrefetch
{
66 LoopDataPrefetch(AssumptionCache
*AC
, DominatorTree
*DT
, LoopInfo
*LI
,
67 ScalarEvolution
*SE
, const TargetTransformInfo
*TTI
,
68 OptimizationRemarkEmitter
*ORE
)
69 : AC(AC
), DT(DT
), LI(LI
), SE(SE
), TTI(TTI
), ORE(ORE
) {}
74 bool runOnLoop(Loop
*L
);
76 /// Check if the stride of the accesses is large enough to
77 /// warrant a prefetch.
78 bool isStrideLargeEnough(const SCEVAddRecExpr
*AR
, unsigned TargetMinStride
);
80 unsigned getMinPrefetchStride(unsigned NumMemAccesses
,
81 unsigned NumStridedMemAccesses
,
82 unsigned NumPrefetches
,
84 if (MinPrefetchStride
.getNumOccurrences() > 0)
85 return MinPrefetchStride
;
86 return TTI
->getMinPrefetchStride(NumMemAccesses
, NumStridedMemAccesses
,
87 NumPrefetches
, HasCall
);
90 unsigned getPrefetchDistance() {
91 if (PrefetchDistance
.getNumOccurrences() > 0)
92 return PrefetchDistance
;
93 return TTI
->getPrefetchDistance();
96 unsigned getMaxPrefetchIterationsAhead() {
97 if (MaxPrefetchIterationsAhead
.getNumOccurrences() > 0)
98 return MaxPrefetchIterationsAhead
;
99 return TTI
->getMaxPrefetchIterationsAhead();
102 bool doPrefetchWrites() {
103 if (PrefetchWrites
.getNumOccurrences() > 0)
104 return PrefetchWrites
;
105 return TTI
->enableWritePrefetching();
112 const TargetTransformInfo
*TTI
;
113 OptimizationRemarkEmitter
*ORE
;
116 /// Legacy class for inserting loop data prefetches.
117 class LoopDataPrefetchLegacyPass
: public FunctionPass
{
119 static char ID
; // Pass ID, replacement for typeid
120 LoopDataPrefetchLegacyPass() : FunctionPass(ID
) {
121 initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
124 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
125 AU
.addRequired
<AssumptionCacheTracker
>();
126 AU
.addRequired
<DominatorTreeWrapperPass
>();
127 AU
.addPreserved
<DominatorTreeWrapperPass
>();
128 AU
.addRequired
<LoopInfoWrapperPass
>();
129 AU
.addPreserved
<LoopInfoWrapperPass
>();
130 AU
.addRequired
<OptimizationRemarkEmitterWrapperPass
>();
131 AU
.addRequired
<ScalarEvolutionWrapperPass
>();
132 AU
.addPreserved
<ScalarEvolutionWrapperPass
>();
133 AU
.addRequired
<TargetTransformInfoWrapperPass
>();
136 bool runOnFunction(Function
&F
) override
;
140 char LoopDataPrefetchLegacyPass::ID
= 0;
141 INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass
, "loop-data-prefetch",
142 "Loop Data Prefetch", false, false)
143 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker
)
144 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass
)
145 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass
)
146 INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass
)
147 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass
)
148 INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass
, "loop-data-prefetch",
149 "Loop Data Prefetch", false, false)
151 FunctionPass
*llvm::createLoopDataPrefetchPass() {
152 return new LoopDataPrefetchLegacyPass();
155 bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr
*AR
,
156 unsigned TargetMinStride
) {
157 // No need to check if any stride goes.
158 if (TargetMinStride
<= 1)
161 const auto *ConstStride
= dyn_cast
<SCEVConstant
>(AR
->getStepRecurrence(*SE
));
162 // If MinStride is set, don't prefetch unless we can ensure that stride is
167 unsigned AbsStride
= std::abs(ConstStride
->getAPInt().getSExtValue());
168 return TargetMinStride
<= AbsStride
;
171 PreservedAnalyses
LoopDataPrefetchPass::run(Function
&F
,
172 FunctionAnalysisManager
&AM
) {
173 DominatorTree
*DT
= &AM
.getResult
<DominatorTreeAnalysis
>(F
);
174 LoopInfo
*LI
= &AM
.getResult
<LoopAnalysis
>(F
);
175 ScalarEvolution
*SE
= &AM
.getResult
<ScalarEvolutionAnalysis
>(F
);
176 AssumptionCache
*AC
= &AM
.getResult
<AssumptionAnalysis
>(F
);
177 OptimizationRemarkEmitter
*ORE
=
178 &AM
.getResult
<OptimizationRemarkEmitterAnalysis
>(F
);
179 const TargetTransformInfo
*TTI
= &AM
.getResult
<TargetIRAnalysis
>(F
);
181 LoopDataPrefetch
LDP(AC
, DT
, LI
, SE
, TTI
, ORE
);
182 bool Changed
= LDP
.run();
185 PreservedAnalyses PA
;
186 PA
.preserve
<DominatorTreeAnalysis
>();
187 PA
.preserve
<LoopAnalysis
>();
191 return PreservedAnalyses::all();
194 bool LoopDataPrefetchLegacyPass::runOnFunction(Function
&F
) {
198 DominatorTree
*DT
= &getAnalysis
<DominatorTreeWrapperPass
>().getDomTree();
199 LoopInfo
*LI
= &getAnalysis
<LoopInfoWrapperPass
>().getLoopInfo();
200 ScalarEvolution
*SE
= &getAnalysis
<ScalarEvolutionWrapperPass
>().getSE();
201 AssumptionCache
*AC
=
202 &getAnalysis
<AssumptionCacheTracker
>().getAssumptionCache(F
);
203 OptimizationRemarkEmitter
*ORE
=
204 &getAnalysis
<OptimizationRemarkEmitterWrapperPass
>().getORE();
205 const TargetTransformInfo
*TTI
=
206 &getAnalysis
<TargetTransformInfoWrapperPass
>().getTTI(F
);
208 LoopDataPrefetch
LDP(AC
, DT
, LI
, SE
, TTI
, ORE
);
212 bool LoopDataPrefetch::run() {
213 // If PrefetchDistance is not set, don't run the pass. This gives an
214 // opportunity for targets to run this pass for selected subtargets only
215 // (whose TTI sets PrefetchDistance).
216 if (getPrefetchDistance() == 0)
218 assert(TTI
->getCacheLineSize() && "Cache line size is not set for target");
220 bool MadeChange
= false;
223 for (auto L
= df_begin(I
), LE
= df_end(I
); L
!= LE
; ++L
)
224 MadeChange
|= runOnLoop(*L
);
229 /// A record for a potential prefetch made during the initial scan of the
230 /// loop. This is used to let a single prefetch target multiple memory accesses.
232 /// The address formula for this prefetch as returned by ScalarEvolution.
233 const SCEVAddRecExpr
*LSCEVAddRec
;
234 /// The point of insertion for the prefetch instruction.
235 Instruction
*InsertPt
;
236 /// True if targeting a write memory access.
238 /// The (first seen) prefetched instruction.
241 /// Constructor to create a new Prefetch for \p I.
242 Prefetch(const SCEVAddRecExpr
*L
, Instruction
*I
)
243 : LSCEVAddRec(L
), InsertPt(nullptr), Writes(false), MemI(nullptr) {
247 /// Add the instruction \param I to this prefetch. If it's not the first
248 /// one, 'InsertPt' and 'Writes' will be updated as required.
249 /// \param PtrDiff the known constant address difference to the first added
251 void addInstruction(Instruction
*I
, DominatorTree
*DT
= nullptr,
252 int64_t PtrDiff
= 0) {
256 Writes
= isa
<StoreInst
>(I
);
258 BasicBlock
*PrefBB
= InsertPt
->getParent();
259 BasicBlock
*InsBB
= I
->getParent();
260 if (PrefBB
!= InsBB
) {
261 BasicBlock
*DomBB
= DT
->findNearestCommonDominator(PrefBB
, InsBB
);
263 InsertPt
= DomBB
->getTerminator();
266 if (isa
<StoreInst
>(I
) && PtrDiff
== 0)
272 bool LoopDataPrefetch::runOnLoop(Loop
*L
) {
273 bool MadeChange
= false;
275 // Only prefetch in the inner-most loop
276 if (!L
->isInnermost())
279 SmallPtrSet
<const Value
*, 32> EphValues
;
280 CodeMetrics::collectEphemeralValues(L
, AC
, EphValues
);
282 // Calculate the number of iterations ahead to prefetch
284 bool HasCall
= false;
285 for (const auto BB
: L
->blocks()) {
286 // If the loop already has prefetches, then assume that the user knows
287 // what they are doing and don't add any more.
288 for (auto &I
: *BB
) {
289 if (isa
<CallInst
>(&I
) || isa
<InvokeInst
>(&I
)) {
290 if (const Function
*F
= cast
<CallBase
>(I
).getCalledFunction()) {
291 if (F
->getIntrinsicID() == Intrinsic::prefetch
)
293 if (TTI
->isLoweredToCall(F
))
295 } else { // indirect call.
300 Metrics
.analyzeBasicBlock(BB
, *TTI
, EphValues
);
302 unsigned LoopSize
= Metrics
.NumInsts
;
306 unsigned ItersAhead
= getPrefetchDistance() / LoopSize
;
310 if (ItersAhead
> getMaxPrefetchIterationsAhead())
313 unsigned ConstantMaxTripCount
= SE
->getSmallConstantMaxTripCount(L
);
314 if (ConstantMaxTripCount
&& ConstantMaxTripCount
< ItersAhead
+ 1)
317 unsigned NumMemAccesses
= 0;
318 unsigned NumStridedMemAccesses
= 0;
319 SmallVector
<Prefetch
, 16> Prefetches
;
320 for (const auto BB
: L
->blocks())
321 for (auto &I
: *BB
) {
325 if (LoadInst
*LMemI
= dyn_cast
<LoadInst
>(&I
)) {
327 PtrValue
= LMemI
->getPointerOperand();
328 } else if (StoreInst
*SMemI
= dyn_cast
<StoreInst
>(&I
)) {
329 if (!doPrefetchWrites()) continue;
331 PtrValue
= SMemI
->getPointerOperand();
334 unsigned PtrAddrSpace
= PtrValue
->getType()->getPointerAddressSpace();
338 if (L
->isLoopInvariant(PtrValue
))
341 const SCEV
*LSCEV
= SE
->getSCEV(PtrValue
);
342 const SCEVAddRecExpr
*LSCEVAddRec
= dyn_cast
<SCEVAddRecExpr
>(LSCEV
);
345 NumStridedMemAccesses
++;
347 // We don't want to double prefetch individual cache lines. If this
348 // access is known to be within one cache line of some other one that
349 // has already been prefetched, then don't prefetch this one as well.
350 bool DupPref
= false;
351 for (auto &Pref
: Prefetches
) {
352 const SCEV
*PtrDiff
= SE
->getMinusSCEV(LSCEVAddRec
, Pref
.LSCEVAddRec
);
353 if (const SCEVConstant
*ConstPtrDiff
=
354 dyn_cast
<SCEVConstant
>(PtrDiff
)) {
355 int64_t PD
= std::abs(ConstPtrDiff
->getValue()->getSExtValue());
356 if (PD
< (int64_t) TTI
->getCacheLineSize()) {
357 Pref
.addInstruction(MemI
, DT
, PD
);
364 Prefetches
.push_back(Prefetch(LSCEVAddRec
, MemI
));
367 unsigned TargetMinStride
=
368 getMinPrefetchStride(NumMemAccesses
, NumStridedMemAccesses
,
369 Prefetches
.size(), HasCall
);
371 LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
372 << " iterations ahead (loop size: " << LoopSize
<< ") in "
373 << L
->getHeader()->getParent()->getName() << ": " << *L
);
374 LLVM_DEBUG(dbgs() << "Loop has: "
375 << NumMemAccesses
<< " memory accesses, "
376 << NumStridedMemAccesses
<< " strided memory accesses, "
377 << Prefetches
.size() << " potential prefetch(es), "
378 << "a minimum stride of " << TargetMinStride
<< ", "
379 << (HasCall
? "calls" : "no calls") << ".\n");
381 for (auto &P
: Prefetches
) {
382 // Check if the stride of the accesses is large enough to warrant a
384 if (!isStrideLargeEnough(P
.LSCEVAddRec
, TargetMinStride
))
387 const SCEV
*NextLSCEV
= SE
->getAddExpr(P
.LSCEVAddRec
, SE
->getMulExpr(
388 SE
->getConstant(P
.LSCEVAddRec
->getType(), ItersAhead
),
389 P
.LSCEVAddRec
->getStepRecurrence(*SE
)));
390 if (!isSafeToExpand(NextLSCEV
, *SE
))
393 BasicBlock
*BB
= P
.InsertPt
->getParent();
394 Type
*I8Ptr
= Type::getInt8PtrTy(BB
->getContext(), 0/*PtrAddrSpace*/);
395 SCEVExpander
SCEVE(*SE
, BB
->getModule()->getDataLayout(), "prefaddr");
396 Value
*PrefPtrValue
= SCEVE
.expandCodeFor(NextLSCEV
, I8Ptr
, P
.InsertPt
);
398 IRBuilder
<> Builder(P
.InsertPt
);
399 Module
*M
= BB
->getParent()->getParent();
400 Type
*I32
= Type::getInt32Ty(BB
->getContext());
401 Function
*PrefetchFunc
= Intrinsic::getDeclaration(
402 M
, Intrinsic::prefetch
, PrefPtrValue
->getType());
406 ConstantInt::get(I32
, P
.Writes
),
407 ConstantInt::get(I32
, 3), ConstantInt::get(I32
, 1)});
409 LLVM_DEBUG(dbgs() << " Access: "
410 << *P
.MemI
->getOperand(isa
<LoadInst
>(P
.MemI
) ? 0 : 1)
411 << ", SCEV: " << *P
.LSCEVAddRec
<< "\n");
413 return OptimizationRemark(DEBUG_TYPE
, "Prefetched", P
.MemI
)
414 << "prefetched memory access";