//===----------------------- AlignmentFromAssumptions.cpp -----------------===//
//                  Set Load/Store Alignments From Assumptions
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a ScalarEvolution-based transformation to set
// the alignments of loads, stores and memory intrinsics based on the truth
// expressions of assume intrinsics. The primary motivation is to handle
// complex alignment assumptions that apply to vector loads and stores that
// appear after vectorization and unrolling.
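//
// For example (illustrative IR), a 32-byte alignment assumption is expressed
// as an "align" operand bundle on an llvm.assume call:
//
//   call void @llvm.assume(i1 true) ["align"(i8* %a, i64 32)]
//   %v = load <8 x float>, <8 x float>* %p, align 4
//
// If ScalarEvolution can show that %p is a multiple-of-32 displacement from
// %a, the load's alignment can be raised to 32.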
//
//===----------------------------------------------------------------------===//

#include "llvm/IR/Instructions.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"

#define AA_NAME "alignment-from-assumptions"
#define DEBUG_TYPE AA_NAME
using namespace llvm;

STATISTIC(NumLoadAlignChanged,
          "Number of loads changed by alignment assumptions");
STATISTIC(NumStoreAlignChanged,
          "Number of stores changed by alignment assumptions");
STATISTIC(NumMemIntAlignChanged,
          "Number of memory intrinsics changed by alignment assumptions");

namespace {
struct AlignmentFromAssumptions : public FunctionPass {
  static char ID; // Pass identification, replacement for typeid
  AlignmentFromAssumptions() : FunctionPass(ID) {
    initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();

    AU.setPreservesCFG();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
  }

  AlignmentFromAssumptionsPass Impl;
};
}

char AlignmentFromAssumptions::ID = 0;
static const char aip_name[] = "Alignment from assumptions";
INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
                      aip_name, false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
                    aip_name, false, false)

FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
  return new AlignmentFromAssumptions();
}

// Given an expression for the (constant) alignment, AlignSCEV, and an
// expression for the displacement between a pointer and the aligned address,
// DiffSCEV, compute the alignment of the displaced pointer if it can be reduced
// to a constant. Using SCEV to compute alignment handles the case where
// DiffSCEV is a recurrence with constant start such that the aligned offset
// is constant. e.g. {16,+,32} % 32 -> 16.
static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
                                      const SCEV *AlignSCEV,
                                      ScalarEvolution *SE) {
  // DiffUnits = Diff % int64_t(Alignment)
  const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);

  LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
                    << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");

  if (const SCEVConstant *ConstDUSCEV =
          dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
    int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();

    // If the displacement is an exact multiple of the alignment, then the
    // displaced pointer has the same alignment as the aligned pointer, so
    // return the alignment value.
    if (!DiffUnits)
      return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();

    // If the displacement is not an exact multiple, but the remainder is a
    // constant, then return this remainder (but only if it is a power of 2).
    uint64_t DiffUnitsAbs = std::abs(DiffUnits);
    if (isPowerOf2_64(DiffUnitsAbs))
      return Align(DiffUnitsAbs);
  }

  return None;
}

// There is an address given by an offset OffSCEV from AASCEV which has an
// alignment AlignSCEV. Use that information, if possible, to compute a new
// alignment for Ptr.
static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
                             const SCEV *OffSCEV, Value *Ptr,
                             ScalarEvolution *SE) {
  const SCEV *PtrSCEV = SE->getSCEV(Ptr);
  // On a platform with 32-bit allocas, but 64-bit flat/global pointer sizes
  // (*cough* AMDGPU), the effective SCEV type of AASCEV and PtrSCEV
  // may disagree. Trunc/extend so they agree.
  PtrSCEV = SE->getTruncateOrZeroExtend(
      PtrSCEV, SE->getEffectiveSCEVType(AASCEV->getType()));
  const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
  if (isa<SCEVCouldNotCompute>(DiffSCEV))
    return Align(1);

  // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
  // sign-extended OffSCEV to i64, so make sure they agree again.
  DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());

  // What we really want to know is the overall offset to the aligned
  // address. This address is displaced by the provided offset.
  DiffSCEV = SE->getAddExpr(DiffSCEV, OffSCEV);

  LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
                    << *AlignSCEV << " and offset " << *OffSCEV
                    << " using diff " << *DiffSCEV << "\n");

  if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
    LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
    return *NewAlignment;
  }

  if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
    // The relative offset to the alignment assumption did not yield a constant,
    // but we should try harder: if we assume that a is 32-byte aligned, then in
    // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
    // 32-byte aligned, but instead alternate between 32 and 16-byte alignment.
    // As a result, the new alignment will not be a constant, but can still
    // be improved over the default (of 4) to 16.

    const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
    const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);

    LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
                      << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");

    // Now compute the new alignment using the displacement to the value in the
    // first iteration, and also the alignment using the per-iteration delta.
    // If these are the same, then use that answer. Otherwise, use the smaller
    // one, but only if it divides the larger one.
    MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
    MaybeAlign NewIncAlignment =
        getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);

    LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
                      << "\n");
    LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
                      << "\n");

    if (!NewAlignment || !NewIncAlignment)
      return Align(1);

    const Align NewAlign = *NewAlignment;
    const Align NewIncAlign = *NewIncAlignment;
    if (NewAlign > NewIncAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
                        << DebugStr(NewIncAlign) << "\n");
      return NewIncAlign;
    }
    if (NewIncAlign > NewAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                        << "\n");
      return NewAlign;
    }
    assert(NewIncAlign == NewAlign);
    LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                      << "\n");
    return NewAlign;
  }

  return Align(1);
}

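// Decode the "align" operand bundle at index Idx of the given llvm.assume
// call. The bundle carries the assumed-aligned pointer, the alignment, and an
// optional offset, e.g. (illustrative IR):
//
//   call void @llvm.assume(i1 true) ["align"(i8* %a, i64 32)]
//   call void @llvm.assume(i1 true) ["align"(i8* %a, i64 32, i64 28)]
//
// Returns false if the bundle at Idx is not an "align" bundle.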
bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
                                                        unsigned Idx,
                                                        Value *&AAPtr,
                                                        const SCEV *&AlignSCEV,
                                                        const SCEV *&OffSCEV) {
  Type *Int64Ty = Type::getInt64Ty(I->getContext());
  OperandBundleUse AlignOB = I->getOperandBundleAt(Idx);
  if (AlignOB.getTagName() != "align")
    return false;
  assert(AlignOB.Inputs.size() >= 2);
  AAPtr = AlignOB.Inputs[0].get();
  // TODO: Consider accumulating the offset to the base.
  AAPtr = AAPtr->stripPointerCastsSameRepresentation();
  AlignSCEV = SE->getSCEV(AlignOB.Inputs[1].get());
  AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
  if (AlignOB.Inputs.size() == 3)
    OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get());
  else
    OffSCEV = SE->getZero(Int64Ty);
  OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
  return true;
}

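// Process one operand bundle of the given llvm.assume call: if it is an
// "align" bundle, propagate the assumed alignment to every load, store, and
// memory intrinsic reachable through users of the assumed pointer (subject to
// isValidAssumeForContext). Returns true if any alignment was changed.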
bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall,
                                                     unsigned Idx) {
  Value *AAPtr;
  const SCEV *AlignSCEV, *OffSCEV;
  if (!extractAlignmentInfo(ACall, Idx, AAPtr, AlignSCEV, OffSCEV))
    return false;

  // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't
  // affect other users.
  if (isa<ConstantData>(AAPtr))
    return false;

  const SCEV *AASCEV = SE->getSCEV(AAPtr);

  // Apply the assumption to all other users of the specified pointer.
  SmallPtrSet<Instruction *, 32> Visited;
  SmallVector<Instruction*, 16> WorkList;
  for (User *J : AAPtr->users()) {
    if (J == ACall)
      continue;

    if (Instruction *K = dyn_cast<Instruction>(J))
      WorkList.push_back(K);
  }

  while (!WorkList.empty()) {
    Instruction *J = WorkList.pop_back_val();
    if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           LI->getPointerOperand(), SE);
      if (NewAlignment > LI->getAlign()) {
        LI->setAlignment(NewAlignment);
        ++NumLoadAlignChanged;
      }
    } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           SI->getPointerOperand(), SE);
      if (NewAlignment > SI->getAlign()) {
        SI->setAlignment(NewAlignment);
        ++NumStoreAlignChanged;
      }
    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewDestAlignment =
          getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);

      LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
                        << "\n";);
      if (NewDestAlignment > *MI->getDestAlign()) {
        MI->setDestAlignment(NewDestAlignment);
        ++NumMemIntAlignChanged;
      }

      // For memory transfers, there is also a source alignment that
      // can be set.
      if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
        Align NewSrcAlignment =
            getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);

        LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
                          << "\n";);

        if (NewSrcAlignment > *MTI->getSourceAlign()) {
          MTI->setSourceAlignment(NewSrcAlignment);
          ++NumMemIntAlignChanged;
        }
      }
    }

    // Now that we've updated that use of the pointer, look for other uses of
    // the pointer to update.
    Visited.insert(J);
    for (User *UJ : J->users()) {
      Instruction *K = cast<Instruction>(UJ);
      if (!Visited.count(K))
        WorkList.push_back(K);
    }
  }

  return true;
}

bool AlignmentFromAssumptions::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  return Impl.runImpl(F, AC, SE, DT);
}

bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
                                           ScalarEvolution *SE_,
                                           DominatorTree *DT_) {
  SE = SE_;
  DT = DT_;

  bool Changed = false;
  for (auto &AssumeVH : AC.assumptions())
    if (AssumeVH) {
      CallInst *Call = cast<CallInst>(AssumeVH);
      for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
        Changed |= processAssumption(Call, Idx);
    }

  return Changed;
}

PreservedAnalyses
AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {

  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
  ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, AC, &SE, &DT))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<ScalarEvolutionAnalysis>();
  return PA;
}