1 //===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements a Loop Data Prefetching Pass.
11 //===----------------------------------------------------------------------===//
13 #include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
15 #define DEBUG_TYPE "loop-data-prefetch"
16 #include "llvm/ADT/DepthFirstIterator.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/AssumptionCache.h"
19 #include "llvm/Analysis/CodeMetrics.h"
20 #include "llvm/Analysis/LoopInfo.h"
21 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22 #include "llvm/Analysis/ScalarEvolution.h"
23 #include "llvm/Analysis/ScalarEvolutionExpander.h"
24 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
25 #include "llvm/Analysis/TargetTransformInfo.h"
26 #include "llvm/IR/CFG.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Transforms/Scalar.h"
33 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
34 #include "llvm/Transforms/Utils/ValueMapper.h"
37 // By default, we limit this to creating 16 PHIs (which is a little over half
38 // of the allocatable register set).
40 PrefetchWrites("loop-prefetch-writes", cl::Hidden
, cl::init(false),
41 cl::desc("Prefetch write addresses"));
43 static cl::opt
<unsigned>
44 PrefetchDistance("prefetch-distance",
45 cl::desc("Number of instructions to prefetch ahead"),
48 static cl::opt
<unsigned>
49 MinPrefetchStride("min-prefetch-stride",
50 cl::desc("Min stride to add prefetches"), cl::Hidden
);
52 static cl::opt
<unsigned> MaxPrefetchIterationsAhead(
53 "max-prefetch-iters-ahead",
54 cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden
);
56 STATISTIC(NumPrefetches
, "Number of prefetches inserted");
60 /// Loop prefetch implementation class.
61 class LoopDataPrefetch
{
63 LoopDataPrefetch(AssumptionCache
*AC
, LoopInfo
*LI
, ScalarEvolution
*SE
,
64 const TargetTransformInfo
*TTI
,
65 OptimizationRemarkEmitter
*ORE
)
66 : AC(AC
), LI(LI
), SE(SE
), TTI(TTI
), ORE(ORE
) {}
71 bool runOnLoop(Loop
*L
);
73 /// Check if the stride of the accesses is large enough to
74 /// warrant a prefetch.
75 bool isStrideLargeEnough(const SCEVAddRecExpr
*AR
);
77 unsigned getMinPrefetchStride() {
78 if (MinPrefetchStride
.getNumOccurrences() > 0)
79 return MinPrefetchStride
;
80 return TTI
->getMinPrefetchStride();
83 unsigned getPrefetchDistance() {
84 if (PrefetchDistance
.getNumOccurrences() > 0)
85 return PrefetchDistance
;
86 return TTI
->getPrefetchDistance();
89 unsigned getMaxPrefetchIterationsAhead() {
90 if (MaxPrefetchIterationsAhead
.getNumOccurrences() > 0)
91 return MaxPrefetchIterationsAhead
;
92 return TTI
->getMaxPrefetchIterationsAhead();
98 const TargetTransformInfo
*TTI
;
99 OptimizationRemarkEmitter
*ORE
;
102 /// Legacy class for inserting loop data prefetches.
103 class LoopDataPrefetchLegacyPass
: public FunctionPass
{
105 static char ID
; // Pass ID, replacement for typeid
106 LoopDataPrefetchLegacyPass() : FunctionPass(ID
) {
107 initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
110 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
111 AU
.addRequired
<AssumptionCacheTracker
>();
112 AU
.addPreserved
<DominatorTreeWrapperPass
>();
113 AU
.addRequired
<LoopInfoWrapperPass
>();
114 AU
.addPreserved
<LoopInfoWrapperPass
>();
115 AU
.addRequired
<OptimizationRemarkEmitterWrapperPass
>();
116 AU
.addRequired
<ScalarEvolutionWrapperPass
>();
117 AU
.addPreserved
<ScalarEvolutionWrapperPass
>();
118 AU
.addRequired
<TargetTransformInfoWrapperPass
>();
121 bool runOnFunction(Function
&F
) override
;
125 char LoopDataPrefetchLegacyPass::ID
= 0;
126 INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass
, "loop-data-prefetch",
127 "Loop Data Prefetch", false, false)
128 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker
)
129 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass
)
130 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass
)
131 INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass
)
132 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass
)
133 INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass
, "loop-data-prefetch",
134 "Loop Data Prefetch", false, false)
136 FunctionPass
*llvm::createLoopDataPrefetchPass() {
137 return new LoopDataPrefetchLegacyPass();
140 bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr
*AR
) {
141 unsigned TargetMinStride
= getMinPrefetchStride();
142 // No need to check if any stride goes.
143 if (TargetMinStride
<= 1)
146 const auto *ConstStride
= dyn_cast
<SCEVConstant
>(AR
->getStepRecurrence(*SE
));
147 // If MinStride is set, don't prefetch unless we can ensure that stride is
152 unsigned AbsStride
= std::abs(ConstStride
->getAPInt().getSExtValue());
153 return TargetMinStride
<= AbsStride
;
156 PreservedAnalyses
LoopDataPrefetchPass::run(Function
&F
,
157 FunctionAnalysisManager
&AM
) {
158 LoopInfo
*LI
= &AM
.getResult
<LoopAnalysis
>(F
);
159 ScalarEvolution
*SE
= &AM
.getResult
<ScalarEvolutionAnalysis
>(F
);
160 AssumptionCache
*AC
= &AM
.getResult
<AssumptionAnalysis
>(F
);
161 OptimizationRemarkEmitter
*ORE
=
162 &AM
.getResult
<OptimizationRemarkEmitterAnalysis
>(F
);
163 const TargetTransformInfo
*TTI
= &AM
.getResult
<TargetIRAnalysis
>(F
);
165 LoopDataPrefetch
LDP(AC
, LI
, SE
, TTI
, ORE
);
166 bool Changed
= LDP
.run();
169 PreservedAnalyses PA
;
170 PA
.preserve
<DominatorTreeAnalysis
>();
171 PA
.preserve
<LoopAnalysis
>();
175 return PreservedAnalyses::all();
178 bool LoopDataPrefetchLegacyPass::runOnFunction(Function
&F
) {
182 LoopInfo
*LI
= &getAnalysis
<LoopInfoWrapperPass
>().getLoopInfo();
183 ScalarEvolution
*SE
= &getAnalysis
<ScalarEvolutionWrapperPass
>().getSE();
184 AssumptionCache
*AC
=
185 &getAnalysis
<AssumptionCacheTracker
>().getAssumptionCache(F
);
186 OptimizationRemarkEmitter
*ORE
=
187 &getAnalysis
<OptimizationRemarkEmitterWrapperPass
>().getORE();
188 const TargetTransformInfo
*TTI
=
189 &getAnalysis
<TargetTransformInfoWrapperPass
>().getTTI(F
);
191 LoopDataPrefetch
LDP(AC
, LI
, SE
, TTI
, ORE
);
195 bool LoopDataPrefetch::run() {
196 // If PrefetchDistance is not set, don't run the pass. This gives an
197 // opportunity for targets to run this pass for selected subtargets only
198 // (whose TTI sets PrefetchDistance).
199 if (getPrefetchDistance() == 0)
201 assert(TTI
->getCacheLineSize() && "Cache line size is not set for target");
203 bool MadeChange
= false;
206 for (auto L
= df_begin(I
), LE
= df_end(I
); L
!= LE
; ++L
)
207 MadeChange
|= runOnLoop(*L
);
212 bool LoopDataPrefetch::runOnLoop(Loop
*L
) {
213 bool MadeChange
= false;
215 // Only prefetch in the inner-most loop
219 SmallPtrSet
<const Value
*, 32> EphValues
;
220 CodeMetrics::collectEphemeralValues(L
, AC
, EphValues
);
222 // Calculate the number of iterations ahead to prefetch
224 for (const auto BB
: L
->blocks()) {
225 // If the loop already has prefetches, then assume that the user knows
226 // what they are doing and don't add any more.
228 if (CallInst
*CI
= dyn_cast
<CallInst
>(&I
))
229 if (Function
*F
= CI
->getCalledFunction())
230 if (F
->getIntrinsicID() == Intrinsic::prefetch
)
233 Metrics
.analyzeBasicBlock(BB
, *TTI
, EphValues
);
235 unsigned LoopSize
= Metrics
.NumInsts
;
239 unsigned ItersAhead
= getPrefetchDistance() / LoopSize
;
243 if (ItersAhead
> getMaxPrefetchIterationsAhead())
246 LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
247 << " iterations ahead (loop size: " << LoopSize
<< ") in "
248 << L
->getHeader()->getParent()->getName() << ": " << *L
);
250 SmallVector
<std::pair
<Instruction
*, const SCEVAddRecExpr
*>, 16> PrefLoads
;
251 for (const auto BB
: L
->blocks()) {
252 for (auto &I
: *BB
) {
256 if (LoadInst
*LMemI
= dyn_cast
<LoadInst
>(&I
)) {
258 PtrValue
= LMemI
->getPointerOperand();
259 } else if (StoreInst
*SMemI
= dyn_cast
<StoreInst
>(&I
)) {
260 if (!PrefetchWrites
) continue;
262 PtrValue
= SMemI
->getPointerOperand();
265 unsigned PtrAddrSpace
= PtrValue
->getType()->getPointerAddressSpace();
269 if (L
->isLoopInvariant(PtrValue
))
272 const SCEV
*LSCEV
= SE
->getSCEV(PtrValue
);
273 const SCEVAddRecExpr
*LSCEVAddRec
= dyn_cast
<SCEVAddRecExpr
>(LSCEV
);
277 // Check if the stride of the accesses is large enough to warrant a
279 if (!isStrideLargeEnough(LSCEVAddRec
))
282 // We don't want to double prefetch individual cache lines. If this load
283 // is known to be within one cache line of some other load that has
284 // already been prefetched, then don't prefetch this one as well.
285 bool DupPref
= false;
286 for (const auto &PrefLoad
: PrefLoads
) {
287 const SCEV
*PtrDiff
= SE
->getMinusSCEV(LSCEVAddRec
, PrefLoad
.second
);
288 if (const SCEVConstant
*ConstPtrDiff
=
289 dyn_cast
<SCEVConstant
>(PtrDiff
)) {
290 int64_t PD
= std::abs(ConstPtrDiff
->getValue()->getSExtValue());
291 if (PD
< (int64_t) TTI
->getCacheLineSize()) {
300 const SCEV
*NextLSCEV
= SE
->getAddExpr(LSCEVAddRec
, SE
->getMulExpr(
301 SE
->getConstant(LSCEVAddRec
->getType(), ItersAhead
),
302 LSCEVAddRec
->getStepRecurrence(*SE
)));
303 if (!isSafeToExpand(NextLSCEV
, *SE
))
306 PrefLoads
.push_back(std::make_pair(MemI
, LSCEVAddRec
));
308 Type
*I8Ptr
= Type::getInt8PtrTy(BB
->getContext(), PtrAddrSpace
);
309 SCEVExpander
SCEVE(*SE
, I
.getModule()->getDataLayout(), "prefaddr");
310 Value
*PrefPtrValue
= SCEVE
.expandCodeFor(NextLSCEV
, I8Ptr
, MemI
);
312 IRBuilder
<> Builder(MemI
);
313 Module
*M
= BB
->getParent()->getParent();
314 Type
*I32
= Type::getInt32Ty(BB
->getContext());
315 Function
*PrefetchFunc
=
316 Intrinsic::getDeclaration(M
, Intrinsic::prefetch
);
320 ConstantInt::get(I32
, MemI
->mayReadFromMemory() ? 0 : 1),
321 ConstantInt::get(I32
, 3), ConstantInt::get(I32
, 1)});
323 LLVM_DEBUG(dbgs() << " Access: " << *PtrValue
<< ", SCEV: " << *LSCEV
326 return OptimizationRemark(DEBUG_TYPE
, "Prefetched", MemI
)
327 << "prefetched memory access";