//===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a Loop Data Prefetching Pass.
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
#include "llvm/InitializePasses.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
34 #define DEBUG_TYPE "loop-data-prefetch"
38 // By default, we limit this to creating 16 PHIs (which is a little over half
39 // of the allocatable register set).
41 PrefetchWrites("loop-prefetch-writes", cl::Hidden
, cl::init(false),
42 cl::desc("Prefetch write addresses"));
44 static cl::opt
<unsigned>
45 PrefetchDistance("prefetch-distance",
46 cl::desc("Number of instructions to prefetch ahead"),
49 static cl::opt
<unsigned>
50 MinPrefetchStride("min-prefetch-stride",
51 cl::desc("Min stride to add prefetches"), cl::Hidden
);
53 static cl::opt
<unsigned> MaxPrefetchIterationsAhead(
54 "max-prefetch-iters-ahead",
55 cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden
);
57 STATISTIC(NumPrefetches
, "Number of prefetches inserted");
61 /// Loop prefetch implementation class.
62 class LoopDataPrefetch
{
64 LoopDataPrefetch(AssumptionCache
*AC
, DominatorTree
*DT
, LoopInfo
*LI
,
65 ScalarEvolution
*SE
, const TargetTransformInfo
*TTI
,
66 OptimizationRemarkEmitter
*ORE
)
67 : AC(AC
), DT(DT
), LI(LI
), SE(SE
), TTI(TTI
), ORE(ORE
) {}
72 bool runOnLoop(Loop
*L
);
74 /// Check if the stride of the accesses is large enough to
75 /// warrant a prefetch.
76 bool isStrideLargeEnough(const SCEVAddRecExpr
*AR
, unsigned TargetMinStride
);
78 unsigned getMinPrefetchStride(unsigned NumMemAccesses
,
79 unsigned NumStridedMemAccesses
,
80 unsigned NumPrefetches
,
82 if (MinPrefetchStride
.getNumOccurrences() > 0)
83 return MinPrefetchStride
;
84 return TTI
->getMinPrefetchStride(NumMemAccesses
, NumStridedMemAccesses
,
85 NumPrefetches
, HasCall
);
88 unsigned getPrefetchDistance() {
89 if (PrefetchDistance
.getNumOccurrences() > 0)
90 return PrefetchDistance
;
91 return TTI
->getPrefetchDistance();
94 unsigned getMaxPrefetchIterationsAhead() {
95 if (MaxPrefetchIterationsAhead
.getNumOccurrences() > 0)
96 return MaxPrefetchIterationsAhead
;
97 return TTI
->getMaxPrefetchIterationsAhead();
100 bool doPrefetchWrites() {
101 if (PrefetchWrites
.getNumOccurrences() > 0)
102 return PrefetchWrites
;
103 return TTI
->enableWritePrefetching();
110 const TargetTransformInfo
*TTI
;
111 OptimizationRemarkEmitter
*ORE
;
114 /// Legacy class for inserting loop data prefetches.
115 class LoopDataPrefetchLegacyPass
: public FunctionPass
{
117 static char ID
; // Pass ID, replacement for typeid
118 LoopDataPrefetchLegacyPass() : FunctionPass(ID
) {
119 initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
122 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
123 AU
.addRequired
<AssumptionCacheTracker
>();
124 AU
.addRequired
<DominatorTreeWrapperPass
>();
125 AU
.addPreserved
<DominatorTreeWrapperPass
>();
126 AU
.addRequired
<LoopInfoWrapperPass
>();
127 AU
.addPreserved
<LoopInfoWrapperPass
>();
128 AU
.addRequiredID(LoopSimplifyID
);
129 AU
.addPreservedID(LoopSimplifyID
);
130 AU
.addRequired
<OptimizationRemarkEmitterWrapperPass
>();
131 AU
.addRequired
<ScalarEvolutionWrapperPass
>();
132 AU
.addPreserved
<ScalarEvolutionWrapperPass
>();
133 AU
.addRequired
<TargetTransformInfoWrapperPass
>();
136 bool runOnFunction(Function
&F
) override
;
140 char LoopDataPrefetchLegacyPass::ID
= 0;
141 INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass
, "loop-data-prefetch",
142 "Loop Data Prefetch", false, false)
143 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker
)
144 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass
)
145 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass
)
146 INITIALIZE_PASS_DEPENDENCY(LoopSimplify
)
147 INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass
)
148 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass
)
149 INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass
, "loop-data-prefetch",
150 "Loop Data Prefetch", false, false)
152 FunctionPass
*llvm::createLoopDataPrefetchPass() {
153 return new LoopDataPrefetchLegacyPass();
156 bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr
*AR
,
157 unsigned TargetMinStride
) {
158 // No need to check if any stride goes.
159 if (TargetMinStride
<= 1)
162 const auto *ConstStride
= dyn_cast
<SCEVConstant
>(AR
->getStepRecurrence(*SE
));
163 // If MinStride is set, don't prefetch unless we can ensure that stride is
168 unsigned AbsStride
= std::abs(ConstStride
->getAPInt().getSExtValue());
169 return TargetMinStride
<= AbsStride
;
172 PreservedAnalyses
LoopDataPrefetchPass::run(Function
&F
,
173 FunctionAnalysisManager
&AM
) {
174 DominatorTree
*DT
= &AM
.getResult
<DominatorTreeAnalysis
>(F
);
175 LoopInfo
*LI
= &AM
.getResult
<LoopAnalysis
>(F
);
176 ScalarEvolution
*SE
= &AM
.getResult
<ScalarEvolutionAnalysis
>(F
);
177 AssumptionCache
*AC
= &AM
.getResult
<AssumptionAnalysis
>(F
);
178 OptimizationRemarkEmitter
*ORE
=
179 &AM
.getResult
<OptimizationRemarkEmitterAnalysis
>(F
);
180 const TargetTransformInfo
*TTI
= &AM
.getResult
<TargetIRAnalysis
>(F
);
182 LoopDataPrefetch
LDP(AC
, DT
, LI
, SE
, TTI
, ORE
);
183 bool Changed
= LDP
.run();
186 PreservedAnalyses PA
;
187 PA
.preserve
<DominatorTreeAnalysis
>();
188 PA
.preserve
<LoopAnalysis
>();
192 return PreservedAnalyses::all();
195 bool LoopDataPrefetchLegacyPass::runOnFunction(Function
&F
) {
199 DominatorTree
*DT
= &getAnalysis
<DominatorTreeWrapperPass
>().getDomTree();
200 LoopInfo
*LI
= &getAnalysis
<LoopInfoWrapperPass
>().getLoopInfo();
201 ScalarEvolution
*SE
= &getAnalysis
<ScalarEvolutionWrapperPass
>().getSE();
202 AssumptionCache
*AC
=
203 &getAnalysis
<AssumptionCacheTracker
>().getAssumptionCache(F
);
204 OptimizationRemarkEmitter
*ORE
=
205 &getAnalysis
<OptimizationRemarkEmitterWrapperPass
>().getORE();
206 const TargetTransformInfo
*TTI
=
207 &getAnalysis
<TargetTransformInfoWrapperPass
>().getTTI(F
);
209 LoopDataPrefetch
LDP(AC
, DT
, LI
, SE
, TTI
, ORE
);
213 bool LoopDataPrefetch::run() {
214 // If PrefetchDistance is not set, don't run the pass. This gives an
215 // opportunity for targets to run this pass for selected subtargets only
216 // (whose TTI sets PrefetchDistance and CacheLineSize).
217 if (getPrefetchDistance() == 0 || TTI
->getCacheLineSize() == 0) {
218 LLVM_DEBUG(dbgs() << "Please set both PrefetchDistance and CacheLineSize "
219 "for loop data prefetch.\n");
223 bool MadeChange
= false;
226 for (Loop
*L
: depth_first(I
))
227 MadeChange
|= runOnLoop(L
);
232 /// A record for a potential prefetch made during the initial scan of the
233 /// loop. This is used to let a single prefetch target multiple memory accesses.
235 /// The address formula for this prefetch as returned by ScalarEvolution.
236 const SCEVAddRecExpr
*LSCEVAddRec
;
237 /// The point of insertion for the prefetch instruction.
238 Instruction
*InsertPt
= nullptr;
239 /// True if targeting a write memory access.
241 /// The (first seen) prefetched instruction.
242 Instruction
*MemI
= nullptr;
244 /// Constructor to create a new Prefetch for \p I.
245 Prefetch(const SCEVAddRecExpr
*L
, Instruction
*I
) : LSCEVAddRec(L
) {
249 /// Add the instruction \param I to this prefetch. If it's not the first
250 /// one, 'InsertPt' and 'Writes' will be updated as required.
251 /// \param PtrDiff the known constant address difference to the first added
253 void addInstruction(Instruction
*I
, DominatorTree
*DT
= nullptr,
254 int64_t PtrDiff
= 0) {
258 Writes
= isa
<StoreInst
>(I
);
260 BasicBlock
*PrefBB
= InsertPt
->getParent();
261 BasicBlock
*InsBB
= I
->getParent();
262 if (PrefBB
!= InsBB
) {
263 BasicBlock
*DomBB
= DT
->findNearestCommonDominator(PrefBB
, InsBB
);
265 InsertPt
= DomBB
->getTerminator();
268 if (isa
<StoreInst
>(I
) && PtrDiff
== 0)
274 bool LoopDataPrefetch::runOnLoop(Loop
*L
) {
275 bool MadeChange
= false;
277 // Only prefetch in the inner-most loop
278 if (!L
->isInnermost())
281 SmallPtrSet
<const Value
*, 32> EphValues
;
282 CodeMetrics::collectEphemeralValues(L
, AC
, EphValues
);
284 // Calculate the number of iterations ahead to prefetch
286 bool HasCall
= false;
287 for (const auto BB
: L
->blocks()) {
288 // If the loop already has prefetches, then assume that the user knows
289 // what they are doing and don't add any more.
290 for (auto &I
: *BB
) {
291 if (isa
<CallInst
>(&I
) || isa
<InvokeInst
>(&I
)) {
292 if (const Function
*F
= cast
<CallBase
>(I
).getCalledFunction()) {
293 if (F
->getIntrinsicID() == Intrinsic::prefetch
)
295 if (TTI
->isLoweredToCall(F
))
297 } else { // indirect call.
302 Metrics
.analyzeBasicBlock(BB
, *TTI
, EphValues
);
305 if (!Metrics
.NumInsts
.isValid())
308 unsigned LoopSize
= *Metrics
.NumInsts
.getValue();
312 unsigned ItersAhead
= getPrefetchDistance() / LoopSize
;
316 if (ItersAhead
> getMaxPrefetchIterationsAhead())
319 unsigned ConstantMaxTripCount
= SE
->getSmallConstantMaxTripCount(L
);
320 if (ConstantMaxTripCount
&& ConstantMaxTripCount
< ItersAhead
+ 1)
323 unsigned NumMemAccesses
= 0;
324 unsigned NumStridedMemAccesses
= 0;
325 SmallVector
<Prefetch
, 16> Prefetches
;
326 for (const auto BB
: L
->blocks())
327 for (auto &I
: *BB
) {
331 if (LoadInst
*LMemI
= dyn_cast
<LoadInst
>(&I
)) {
333 PtrValue
= LMemI
->getPointerOperand();
334 } else if (StoreInst
*SMemI
= dyn_cast
<StoreInst
>(&I
)) {
335 if (!doPrefetchWrites()) continue;
337 PtrValue
= SMemI
->getPointerOperand();
340 unsigned PtrAddrSpace
= PtrValue
->getType()->getPointerAddressSpace();
341 if (!TTI
->shouldPrefetchAddressSpace(PtrAddrSpace
))
344 if (L
->isLoopInvariant(PtrValue
))
347 const SCEV
*LSCEV
= SE
->getSCEV(PtrValue
);
348 const SCEVAddRecExpr
*LSCEVAddRec
= dyn_cast
<SCEVAddRecExpr
>(LSCEV
);
351 NumStridedMemAccesses
++;
353 // We don't want to double prefetch individual cache lines. If this
354 // access is known to be within one cache line of some other one that
355 // has already been prefetched, then don't prefetch this one as well.
356 bool DupPref
= false;
357 for (auto &Pref
: Prefetches
) {
358 const SCEV
*PtrDiff
= SE
->getMinusSCEV(LSCEVAddRec
, Pref
.LSCEVAddRec
);
359 if (const SCEVConstant
*ConstPtrDiff
=
360 dyn_cast
<SCEVConstant
>(PtrDiff
)) {
361 int64_t PD
= std::abs(ConstPtrDiff
->getValue()->getSExtValue());
362 if (PD
< (int64_t) TTI
->getCacheLineSize()) {
363 Pref
.addInstruction(MemI
, DT
, PD
);
370 Prefetches
.push_back(Prefetch(LSCEVAddRec
, MemI
));
373 unsigned TargetMinStride
=
374 getMinPrefetchStride(NumMemAccesses
, NumStridedMemAccesses
,
375 Prefetches
.size(), HasCall
);
377 LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
378 << " iterations ahead (loop size: " << LoopSize
<< ") in "
379 << L
->getHeader()->getParent()->getName() << ": " << *L
);
380 LLVM_DEBUG(dbgs() << "Loop has: "
381 << NumMemAccesses
<< " memory accesses, "
382 << NumStridedMemAccesses
<< " strided memory accesses, "
383 << Prefetches
.size() << " potential prefetch(es), "
384 << "a minimum stride of " << TargetMinStride
<< ", "
385 << (HasCall
? "calls" : "no calls") << ".\n");
387 for (auto &P
: Prefetches
) {
388 // Check if the stride of the accesses is large enough to warrant a
390 if (!isStrideLargeEnough(P
.LSCEVAddRec
, TargetMinStride
))
393 BasicBlock
*BB
= P
.InsertPt
->getParent();
394 SCEVExpander
SCEVE(*SE
, BB
->getModule()->getDataLayout(), "prefaddr");
395 const SCEV
*NextLSCEV
= SE
->getAddExpr(P
.LSCEVAddRec
, SE
->getMulExpr(
396 SE
->getConstant(P
.LSCEVAddRec
->getType(), ItersAhead
),
397 P
.LSCEVAddRec
->getStepRecurrence(*SE
)));
398 if (!SCEVE
.isSafeToExpand(NextLSCEV
))
401 unsigned PtrAddrSpace
= NextLSCEV
->getType()->getPointerAddressSpace();
402 Type
*I8Ptr
= PointerType::get(BB
->getContext(), PtrAddrSpace
);
403 Value
*PrefPtrValue
= SCEVE
.expandCodeFor(NextLSCEV
, I8Ptr
, P
.InsertPt
);
405 IRBuilder
<> Builder(P
.InsertPt
);
406 Module
*M
= BB
->getParent()->getParent();
407 Type
*I32
= Type::getInt32Ty(BB
->getContext());
408 Function
*PrefetchFunc
= Intrinsic::getDeclaration(
409 M
, Intrinsic::prefetch
, PrefPtrValue
->getType());
413 ConstantInt::get(I32
, P
.Writes
),
414 ConstantInt::get(I32
, 3), ConstantInt::get(I32
, 1)});
416 LLVM_DEBUG(dbgs() << " Access: "
417 << *P
.MemI
->getOperand(isa
<LoadInst
>(P
.MemI
) ? 0 : 1)
418 << ", SCEV: " << *P
.LSCEVAddRec
<< "\n");
420 return OptimizationRemark(DEBUG_TYPE
, "Prefetched", P
.MemI
)
421 << "prefetched memory access";