//===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a Loop Data Prefetching Pass.
//
//===----------------------------------------------------------------------===//
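// Overview: for each innermost loop, the pass estimates how many iterations
// ahead to prefetch (the target's prefetch distance divided by the loop size
// in instructions), collects loads (and optionally stores) whose addresses
// are SCEV add recurrences, merges accesses that share a cache line, and
// emits an llvm.prefetch intrinsic for the address each remaining access
// will touch that many iterations ahead. Conceptually (a source-level view,
// not the actual IR transformation), with ItersAhead == 8:
//
//   for (i = 0; i < n; i++)
//     sum += a[i];
//
// becomes
//
//   for (i = 0; i < n; i++) {
//     __builtin_prefetch(&a[i + 8]);
//     sum += a[i];
//   }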
#include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
#include "llvm/InitializePasses.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#define DEBUG_TYPE "loop-data-prefetch"

using namespace llvm;

// By default, we limit this to creating 16 PHIs (which is a little over half
// of the allocatable register set).
static cl::opt<bool>
    PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
                   cl::desc("Prefetch write addresses"));

static cl::opt<unsigned>
    PrefetchDistance("prefetch-distance",
                     cl::desc("Number of instructions to prefetch ahead"),
                     cl::Hidden);

static cl::opt<unsigned>
    MinPrefetchStride("min-prefetch-stride",
                      cl::desc("Min stride to add prefetches"), cl::Hidden);

static cl::opt<unsigned> MaxPrefetchIterationsAhead(
    "max-prefetch-iters-ahead",
    cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);

STATISTIC(NumPrefetches, "Number of prefetches inserted");
namespace {

/// Loop prefetch implementation class.
class LoopDataPrefetch {
public:
  LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI,
                   ScalarEvolution *SE, const TargetTransformInfo *TTI,
                   OptimizationRemarkEmitter *ORE)
      : AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}

  bool run();

private:
  bool runOnLoop(Loop *L);

  /// Check if the stride of the accesses is large enough to warrant a
  /// prefetch.
  bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride);
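  // Each getter below prefers its command-line flag when the flag was
  // explicitly set, and otherwise falls back to the corresponding
  // TargetTransformInfo hook.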
  unsigned getMinPrefetchStride(unsigned NumMemAccesses,
                                unsigned NumStridedMemAccesses,
                                unsigned NumPrefetches, bool HasCall) {
    if (MinPrefetchStride.getNumOccurrences() > 0)
      return MinPrefetchStride;
    return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
                                     NumPrefetches, HasCall);
  }

  unsigned getPrefetchDistance() {
    if (PrefetchDistance.getNumOccurrences() > 0)
      return PrefetchDistance;
    return TTI->getPrefetchDistance();
  }

  unsigned getMaxPrefetchIterationsAhead() {
    if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0)
      return MaxPrefetchIterationsAhead;
    return TTI->getMaxPrefetchIterationsAhead();
  }

  bool doPrefetchWrites() {
    if (PrefetchWrites.getNumOccurrences() > 0)
      return PrefetchWrites;
    return TTI->enableWritePrefetching();
  }
  AssumptionCache *AC;
  DominatorTree *DT;
  LoopInfo *LI;
  ScalarEvolution *SE;
  const TargetTransformInfo *TTI;
  OptimizationRemarkEmitter *ORE;
};
/// Legacy class for inserting loop data prefetches.
class LoopDataPrefetchLegacyPass : public FunctionPass {
public:
  static char ID; // Pass ID, replacement for typeid
  LoopDataPrefetchLegacyPass() : FunctionPass(ID) {
    initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequiredID(LoopSimplifyID);
    AU.addPreservedID(LoopSimplifyID);
    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  bool runOnFunction(Function &F) override;
};

} // end anonymous namespace
char LoopDataPrefetchLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
                      "Loop Data Prefetch", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
                    "Loop Data Prefetch", false, false)
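// The macros above register the legacy pass under the "loop-data-prefetch"
// name and declare its analysis dependencies to the legacy pass manager.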
FunctionPass *llvm::createLoopDataPrefetchPass() {
  return new LoopDataPrefetchLegacyPass();
}
bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR,
                                           unsigned TargetMinStride) {
  // No need to check if any stride is acceptable.
  if (TargetMinStride <= 1)
    return true;

  const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
  // If MinStride is set, don't prefetch unless we can ensure that the stride
  // is large enough.
  if (!ConstStride)
    return false;

  unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue());
  return TargetMinStride <= AbsStride;
}
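// For example, if the target reports a 64-byte minimum prefetch stride,
// isStrideLargeEnough rejects an access stepping 4 bytes per iteration but
// accepts one stepping 128 bytes; non-constant strides are always rejected
// once a minimum stride is in force.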
PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
                                            FunctionAnalysisManager &AM) {
  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);
  ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
  AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
  OptimizationRemarkEmitter *ORE =
      &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);

  LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
  bool Changed = LDP.run();

  if (Changed) {
    PreservedAnalyses PA;
    PA.preserve<DominatorTreeAnalysis>();
    PA.preserve<LoopAnalysis>();
    return PA;
  }

  return PreservedAnalyses::all();
}
bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  AssumptionCache *AC =
      &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  OptimizationRemarkEmitter *ORE =
      &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
  const TargetTransformInfo *TTI =
      &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);
  return LDP.run();
}
bool LoopDataPrefetch::run() {
  // If PrefetchDistance is not set, don't run the pass. This gives targets
  // an opportunity to run this pass for selected subtargets only (those
  // whose TTI sets PrefetchDistance and CacheLineSize).
  if (getPrefetchDistance() == 0 || TTI->getCacheLineSize() == 0) {
    LLVM_DEBUG(dbgs() << "Please set both PrefetchDistance and CacheLineSize "
                         "for loop data prefetch.\n");
    return false;
  }

  bool MadeChange = false;

  for (Loop *I : *LI)
    for (Loop *L : depth_first(I))
      MadeChange |= runOnLoop(L);

  return MadeChange;
}
/// A record for a potential prefetch made during the initial scan of the
/// loop. This is used to let a single prefetch target multiple memory
/// accesses.
struct Prefetch {
  /// The address formula for this prefetch as returned by ScalarEvolution.
  const SCEVAddRecExpr *LSCEVAddRec;

  /// The point of insertion for the prefetch instruction.
  Instruction *InsertPt = nullptr;

  /// True if targeting a write memory access.
  bool Writes = false;

  /// The (first seen) prefetched instruction.
  Instruction *MemI = nullptr;

  /// Constructor to create a new Prefetch for \p I.
  Prefetch(const SCEVAddRecExpr *L, Instruction *I) : LSCEVAddRec(L) {
    addInstruction(I);
  }

  /// Add the instruction \param I to this prefetch. If it's not the first
  /// one, 'InsertPt' and 'Writes' will be updated as required.
  /// \param PtrDiff the known constant address difference to the first added
  /// instruction.
  void addInstruction(Instruction *I, DominatorTree *DT = nullptr,
                      int64_t PtrDiff = 0) {
    if (!InsertPt) {
      MemI = I;
      InsertPt = I;
      Writes = isa<StoreInst>(I);
    } else {
      BasicBlock *PrefBB = InsertPt->getParent();
      BasicBlock *InsBB = I->getParent();
      if (PrefBB != InsBB) {
        BasicBlock *DomBB = DT->findNearestCommonDominator(PrefBB, InsBB);
        if (DomBB != PrefBB)
          InsertPt = DomBB->getTerminator();
      }

      if (isa<StoreInst>(I) && PtrDiff == 0)
        Writes = true;
    }
  }
};
bool LoopDataPrefetch::runOnLoop(Loop *L) {
  bool MadeChange = false;

  // Only prefetch in the innermost loops.
  if (!L->isInnermost())
    return MadeChange;

  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, AC, EphValues);
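  // Ephemeral values (those only feeding llvm.assume and similar) are
  // collected so that the loop-size estimate below does not count them.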
  // Calculate the number of iterations ahead to prefetch.
  CodeMetrics Metrics;
  bool HasCall = false;
  for (const auto BB : L->blocks()) {
    // If the loop already has prefetches, then assume that the user knows
    // what they are doing and don't add any more.
    for (auto &I : *BB) {
      if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) {
        if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
          if (F->getIntrinsicID() == Intrinsic::prefetch)
            return MadeChange;
          if (TTI->isLoweredToCall(F))
            HasCall = true;
        } else { // Indirect call.
          HasCall = true;
        }
      }
    }
    Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
  }

  if (!Metrics.NumInsts.isValid())
    return MadeChange;

  unsigned LoopSize = *Metrics.NumInsts.getValue();
  if (!LoopSize)
    LoopSize = 1;
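  // The look-ahead is measured in whole loop iterations: for example, a
  // target prefetch distance of 64 instructions over an 8-instruction loop
  // body gives ItersAhead = 64 / 8 = 8 iterations.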
  unsigned ItersAhead = getPrefetchDistance() / LoopSize;
  if (!ItersAhead)
    ItersAhead = 1;

  if (ItersAhead > getMaxPrefetchIterationsAhead())
    return MadeChange;

  unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L);
  if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1)
    return MadeChange;
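  // Scan the loop for loads (and, when enabled, stores) worth prefetching,
  // recording each candidate as a Prefetch entry.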
  unsigned NumMemAccesses = 0;
  unsigned NumStridedMemAccesses = 0;
  SmallVector<Prefetch, 16> Prefetches;
  for (const auto BB : L->blocks())
    for (auto &I : *BB) {
      Value *PtrValue;
      Instruction *MemI;

      if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) {
        MemI = LMemI;
        PtrValue = LMemI->getPointerOperand();
      } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
        if (!doPrefetchWrites())
          continue;
        MemI = SMemI;
        PtrValue = SMemI->getPointerOperand();
      } else
        continue;

      unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
      if (!TTI->shouldPrefetchAddressSpace(PtrAddrSpace))
        continue;
      NumMemAccesses++;
      if (L->isLoopInvariant(PtrValue))
        continue;

      const SCEV *LSCEV = SE->getSCEV(PtrValue);
      const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
      if (!LSCEVAddRec)
        continue;
      NumStridedMemAccesses++;
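      // Example of the merging below: with a 64-byte cache line, i32 loads
      // of A[i] and A[i + 4] are 16 bytes apart and thus within one cache
      // line of each other, so the second access joins the first one's
      // Prefetch record instead of getting a prefetch of its own.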
      // We don't want to double prefetch individual cache lines. If this
      // access is known to be within one cache line of some other one that
      // has already been prefetched, then don't prefetch this one as well.
      bool DupPref = false;
      for (auto &Pref : Prefetches) {
        const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec);
        if (const SCEVConstant *ConstPtrDiff =
                dyn_cast<SCEVConstant>(PtrDiff)) {
          int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
          if (PD < (int64_t)TTI->getCacheLineSize()) {
            Pref.addInstruction(MemI, DT, PD);
            DupPref = true;
            break;
          }
        }
      }
      if (!DupPref)
        Prefetches.push_back(Prefetch(LSCEVAddRec, MemI));
    }
  unsigned TargetMinStride =
      getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses,
                           Prefetches.size(), HasCall);
  LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
                    << " iterations ahead (loop size: " << LoopSize << ") in "
                    << L->getHeader()->getParent()->getName() << ": " << *L);
  LLVM_DEBUG(dbgs() << "Loop has: " << NumMemAccesses << " memory accesses, "
                    << NumStridedMemAccesses << " strided memory accesses, "
                    << Prefetches.size() << " potential prefetch(es), "
                    << "a minimum stride of " << TargetMinStride << ", "
                    << (HasCall ? "calls" : "no calls") << ".\n");
  for (auto &P : Prefetches) {
    // Check if the stride of the accesses is large enough to warrant a
    // prefetch.
    if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride))
      continue;
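    // Compute the address this access will touch ItersAhead iterations from
    // now: NextLSCEV = AddRec + ItersAhead * Step. For a 4-byte stride and
    // ItersAhead == 8, that is the current address plus 32 bytes.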
    BasicBlock *BB = P.InsertPt->getParent();
    SCEVExpander SCEVE(*SE, BB->getDataLayout(), "prefaddr");
    const SCEV *NextLSCEV = SE->getAddExpr(
        P.LSCEVAddRec,
        SE->getMulExpr(SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead),
                       P.LSCEVAddRec->getStepRecurrence(*SE)));
    if (!SCEVE.isSafeToExpand(NextLSCEV))
      continue;
    unsigned PtrAddrSpace = NextLSCEV->getType()->getPointerAddressSpace();
    Type *I8Ptr = PointerType::get(BB->getContext(), PtrAddrSpace);
    Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt);
    IRBuilder<> Builder(P.InsertPt);
    Type *I32 = Type::getInt32Ty(BB->getContext());
    Builder.CreateIntrinsic(Intrinsic::prefetch, PrefPtrValue->getType(),
                            {PrefPtrValue, ConstantInt::get(I32, P.Writes),
                             ConstantInt::get(I32, 3),
                             ConstantInt::get(I32, 1)});
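    // The llvm.prefetch operands are: the address, rw (0 = read, 1 = write),
    // locality (3 = maximum temporal locality), and cache type (1 = data).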
    ++NumPrefetches;
    LLVM_DEBUG(dbgs() << "  Access: "
                      << *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 0 : 1)
                      << ", SCEV: " << *P.LSCEVAddRec << "\n");
    ORE->emit([&]() {
      return OptimizationRemark(DEBUG_TYPE, "Prefetched", P.MemI)
             << "prefetched memory access";
    });

    MadeChange = true;
  }

  return MadeChange;
}