[llvm-exegesis] [NFC] Fixing typo.
[llvm-complete.git] / lib / Transforms / Scalar / LoopDataPrefetch.cpp
blob1fcf1315a1777a1c2aad345a4bd660e9f3700913
1 //===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a Loop Data Prefetching Pass.
11 //===----------------------------------------------------------------------===//
13 #include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
15 #define DEBUG_TYPE "loop-data-prefetch"
16 #include "llvm/ADT/DepthFirstIterator.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/AssumptionCache.h"
19 #include "llvm/Analysis/CodeMetrics.h"
20 #include "llvm/Analysis/LoopInfo.h"
21 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22 #include "llvm/Analysis/ScalarEvolution.h"
23 #include "llvm/Analysis/ScalarEvolutionExpander.h"
24 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
25 #include "llvm/Analysis/TargetTransformInfo.h"
26 #include "llvm/IR/CFG.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Transforms/Scalar.h"
33 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
34 #include "llvm/Transforms/Utils/ValueMapper.h"
35 using namespace llvm;
37 // By default, we limit this to creating 16 PHIs (which is a little over half
38 // of the allocatable register set).
39 static cl::opt<bool>
40 PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
41 cl::desc("Prefetch write addresses"));
43 static cl::opt<unsigned>
44 PrefetchDistance("prefetch-distance",
45 cl::desc("Number of instructions to prefetch ahead"),
46 cl::Hidden);
48 static cl::opt<unsigned>
49 MinPrefetchStride("min-prefetch-stride",
50 cl::desc("Min stride to add prefetches"), cl::Hidden);
52 static cl::opt<unsigned> MaxPrefetchIterationsAhead(
53 "max-prefetch-iters-ahead",
54 cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);
56 STATISTIC(NumPrefetches, "Number of prefetches inserted");
58 namespace {
60 /// Loop prefetch implementation class.
61 class LoopDataPrefetch {
62 public:
63 LoopDataPrefetch(AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE,
64 const TargetTransformInfo *TTI,
65 OptimizationRemarkEmitter *ORE)
66 : AC(AC), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}
68 bool run();
70 private:
71 bool runOnLoop(Loop *L);
73 /// Check if the stride of the accesses is large enough to
74 /// warrant a prefetch.
75 bool isStrideLargeEnough(const SCEVAddRecExpr *AR);
77 unsigned getMinPrefetchStride() {
78 if (MinPrefetchStride.getNumOccurrences() > 0)
79 return MinPrefetchStride;
80 return TTI->getMinPrefetchStride();
83 unsigned getPrefetchDistance() {
84 if (PrefetchDistance.getNumOccurrences() > 0)
85 return PrefetchDistance;
86 return TTI->getPrefetchDistance();
89 unsigned getMaxPrefetchIterationsAhead() {
90 if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0)
91 return MaxPrefetchIterationsAhead;
92 return TTI->getMaxPrefetchIterationsAhead();
95 AssumptionCache *AC;
96 LoopInfo *LI;
97 ScalarEvolution *SE;
98 const TargetTransformInfo *TTI;
99 OptimizationRemarkEmitter *ORE;
102 /// Legacy class for inserting loop data prefetches.
103 class LoopDataPrefetchLegacyPass : public FunctionPass {
104 public:
105 static char ID; // Pass ID, replacement for typeid
106 LoopDataPrefetchLegacyPass() : FunctionPass(ID) {
107 initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry());
110 void getAnalysisUsage(AnalysisUsage &AU) const override {
111 AU.addRequired<AssumptionCacheTracker>();
112 AU.addPreserved<DominatorTreeWrapperPass>();
113 AU.addRequired<LoopInfoWrapperPass>();
114 AU.addPreserved<LoopInfoWrapperPass>();
115 AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
116 AU.addRequired<ScalarEvolutionWrapperPass>();
117 AU.addPreserved<ScalarEvolutionWrapperPass>();
118 AU.addRequired<TargetTransformInfoWrapperPass>();
121 bool runOnFunction(Function &F) override;
125 char LoopDataPrefetchLegacyPass::ID = 0;
126 INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
127 "Loop Data Prefetch", false, false)
128 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
129 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
130 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
131 INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
132 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
133 INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch",
134 "Loop Data Prefetch", false, false)
136 FunctionPass *llvm::createLoopDataPrefetchPass() {
137 return new LoopDataPrefetchLegacyPass();
140 bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) {
141 unsigned TargetMinStride = getMinPrefetchStride();
142 // No need to check if any stride goes.
143 if (TargetMinStride <= 1)
144 return true;
146 const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE));
147 // If MinStride is set, don't prefetch unless we can ensure that stride is
148 // larger.
149 if (!ConstStride)
150 return false;
152 unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue());
153 return TargetMinStride <= AbsStride;
156 PreservedAnalyses LoopDataPrefetchPass::run(Function &F,
157 FunctionAnalysisManager &AM) {
158 LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);
159 ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
160 AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F);
161 OptimizationRemarkEmitter *ORE =
162 &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
163 const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F);
165 LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE);
166 bool Changed = LDP.run();
168 if (Changed) {
169 PreservedAnalyses PA;
170 PA.preserve<DominatorTreeAnalysis>();
171 PA.preserve<LoopAnalysis>();
172 return PA;
175 return PreservedAnalyses::all();
178 bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {
179 if (skipFunction(F))
180 return false;
182 LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
183 ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
184 AssumptionCache *AC =
185 &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
186 OptimizationRemarkEmitter *ORE =
187 &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
188 const TargetTransformInfo *TTI =
189 &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
191 LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE);
192 return LDP.run();
195 bool LoopDataPrefetch::run() {
196 // If PrefetchDistance is not set, don't run the pass. This gives an
197 // opportunity for targets to run this pass for selected subtargets only
198 // (whose TTI sets PrefetchDistance).
199 if (getPrefetchDistance() == 0)
200 return false;
201 assert(TTI->getCacheLineSize() && "Cache line size is not set for target");
203 bool MadeChange = false;
205 for (Loop *I : *LI)
206 for (auto L = df_begin(I), LE = df_end(I); L != LE; ++L)
207 MadeChange |= runOnLoop(*L);
209 return MadeChange;
212 bool LoopDataPrefetch::runOnLoop(Loop *L) {
213 bool MadeChange = false;
215 // Only prefetch in the inner-most loop
216 if (!L->empty())
217 return MadeChange;
219 SmallPtrSet<const Value *, 32> EphValues;
220 CodeMetrics::collectEphemeralValues(L, AC, EphValues);
222 // Calculate the number of iterations ahead to prefetch
223 CodeMetrics Metrics;
224 for (const auto BB : L->blocks()) {
225 // If the loop already has prefetches, then assume that the user knows
226 // what they are doing and don't add any more.
227 for (auto &I : *BB)
228 if (CallInst *CI = dyn_cast<CallInst>(&I))
229 if (Function *F = CI->getCalledFunction())
230 if (F->getIntrinsicID() == Intrinsic::prefetch)
231 return MadeChange;
233 Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
235 unsigned LoopSize = Metrics.NumInsts;
236 if (!LoopSize)
237 LoopSize = 1;
239 unsigned ItersAhead = getPrefetchDistance() / LoopSize;
240 if (!ItersAhead)
241 ItersAhead = 1;
243 if (ItersAhead > getMaxPrefetchIterationsAhead())
244 return MadeChange;
246 LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead
247 << " iterations ahead (loop size: " << LoopSize << ") in "
248 << L->getHeader()->getParent()->getName() << ": " << *L);
250 SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads;
251 for (const auto BB : L->blocks()) {
252 for (auto &I : *BB) {
253 Value *PtrValue;
254 Instruction *MemI;
256 if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) {
257 MemI = LMemI;
258 PtrValue = LMemI->getPointerOperand();
259 } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
260 if (!PrefetchWrites) continue;
261 MemI = SMemI;
262 PtrValue = SMemI->getPointerOperand();
263 } else continue;
265 unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
266 if (PtrAddrSpace)
267 continue;
269 if (L->isLoopInvariant(PtrValue))
270 continue;
272 const SCEV *LSCEV = SE->getSCEV(PtrValue);
273 const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
274 if (!LSCEVAddRec)
275 continue;
277 // Check if the stride of the accesses is large enough to warrant a
278 // prefetch.
279 if (!isStrideLargeEnough(LSCEVAddRec))
280 continue;
282 // We don't want to double prefetch individual cache lines. If this load
283 // is known to be within one cache line of some other load that has
284 // already been prefetched, then don't prefetch this one as well.
285 bool DupPref = false;
286 for (const auto &PrefLoad : PrefLoads) {
287 const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, PrefLoad.second);
288 if (const SCEVConstant *ConstPtrDiff =
289 dyn_cast<SCEVConstant>(PtrDiff)) {
290 int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
291 if (PD < (int64_t) TTI->getCacheLineSize()) {
292 DupPref = true;
293 break;
297 if (DupPref)
298 continue;
300 const SCEV *NextLSCEV = SE->getAddExpr(LSCEVAddRec, SE->getMulExpr(
301 SE->getConstant(LSCEVAddRec->getType(), ItersAhead),
302 LSCEVAddRec->getStepRecurrence(*SE)));
303 if (!isSafeToExpand(NextLSCEV, *SE))
304 continue;
306 PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec));
308 Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace);
309 SCEVExpander SCEVE(*SE, I.getModule()->getDataLayout(), "prefaddr");
310 Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI);
312 IRBuilder<> Builder(MemI);
313 Module *M = BB->getParent()->getParent();
314 Type *I32 = Type::getInt32Ty(BB->getContext());
315 Function *PrefetchFunc =
316 Intrinsic::getDeclaration(M, Intrinsic::prefetch);
317 Builder.CreateCall(
318 PrefetchFunc,
319 {PrefPtrValue,
320 ConstantInt::get(I32, MemI->mayReadFromMemory() ? 0 : 1),
321 ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)});
322 ++NumPrefetches;
323 LLVM_DEBUG(dbgs() << " Access: " << *PtrValue << ", SCEV: " << *LSCEV
324 << "\n");
325 ORE->emit([&]() {
326 return OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI)
327 << "prefetched memory access";
330 MadeChange = true;
334 return MadeChange;