//===----------------------- AlignmentFromAssumptions.cpp -----------------===//
//                  Set Load/Store Alignments From Assumptions
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a ScalarEvolution-based transformation to set
// the alignments of loads, stores, and memory intrinsics based on the truth
// expressions of assume intrinsics. The primary motivation is to handle
// complex alignment assumptions that apply to vector loads and stores that
// appear after vectorization and unrolling.
//
//===----------------------------------------------------------------------===//
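
// For example, an alignment assumption typically reaches this pass as an
// "align" operand bundle on an llvm.assume call:
//
//   call void @llvm.assume(i1 true) ["align"(ptr %a, i64 32)]
//
// after which loads and stores through %a (and through pointers derived from
// it at offsets that ScalarEvolution can analyze) may have their recorded
// alignment raised to 32.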

#include "llvm/IR/Instructions.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"

#define AA_NAME "alignment-from-assumptions"
#define DEBUG_TYPE AA_NAME

using namespace llvm;

STATISTIC(NumLoadAlignChanged,
          "Number of loads changed by alignment assumptions");
STATISTIC(NumStoreAlignChanged,
          "Number of stores changed by alignment assumptions");
STATISTIC(NumMemIntAlignChanged,
          "Number of memory intrinsics changed by alignment assumptions");

namespace {
struct AlignmentFromAssumptions : public FunctionPass {
  static char ID; // Pass identification, replacement for typeid
  AlignmentFromAssumptions() : FunctionPass(ID) {
    initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();

    AU.setPreservesCFG();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
  }

  AlignmentFromAssumptionsPass Impl;
};
} // end anonymous namespace

char AlignmentFromAssumptions::ID = 0;
static const char aip_name[] = "Alignment from assumptions";
INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
                      aip_name, false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
                    aip_name, false, false)

FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
  return new AlignmentFromAssumptions();
}

// Given an expression for the (constant) alignment, AlignSCEV, and an
// expression for the displacement between a pointer and the aligned address,
// DiffSCEV, compute the alignment of the displaced pointer if it can be
// reduced to a constant. Using SCEV to compute alignment handles the case
// where DiffSCEV is a recurrence with constant start such that the aligned
// offset is constant. e.g. {16,+,32} % 32 -> 16.
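// To spell out that example: every value of the recurrence {16,+,32} is
// congruent to 16 modulo 32, so although the displacement varies from
// iteration to iteration, its remainder against a 32-byte alignment is the
// constant 16, and 16-byte alignment can be inferred.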
static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
                                      const SCEV *AlignSCEV,
                                      ScalarEvolution *SE) {
  // DiffUnits = Diff % int64_t(Alignment)
  const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);

  LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
                    << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");

  if (const SCEVConstant *ConstDUSCEV =
          dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
    int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();

    // If the displacement is an exact multiple of the alignment, then the
    // displaced pointer has the same alignment as the aligned pointer, so
    // return the alignment value.
    if (!DiffUnits)
      return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();

    // If the displacement is not an exact multiple, but the remainder is a
    // constant, then return this remainder (but only if it is a power of 2).
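    // For example, a remainder of 8 against a 32-byte alignment yields
    // Align(8), while a remainder of 20 is not a power of 2 and yields
    // nothing.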
    uint64_t DiffUnitsAbs = std::abs(DiffUnits);
    if (isPowerOf2_64(DiffUnitsAbs))
      return Align(DiffUnitsAbs);
  }

  return MaybeAlign();
}

// There is an address given by an offset OffSCEV from AASCEV which has an
// alignment AlignSCEV. Use that information, if possible, to compute a new
// alignment for Ptr.
static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
                             const SCEV *OffSCEV, Value *Ptr,
                             ScalarEvolution *SE) {
  const SCEV *PtrSCEV = SE->getSCEV(Ptr);
  // On a platform with 32-bit allocas, but 64-bit flat/global pointer sizes
  // (*cough* AMDGPU), the effective SCEV type of AASCEV and PtrSCEV
  // may disagree. Trunc/extend so they agree.
  PtrSCEV = SE->getTruncateOrZeroExtend(
      PtrSCEV, SE->getEffectiveSCEVType(AASCEV->getType()));
  const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
  if (isa<SCEVCouldNotCompute>(DiffSCEV))
    return Align(1);

  // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
  // sign-extended OffSCEV to i64, so make sure they agree again.
  DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());

  // What we really want to know is the overall offset to the aligned
  // address. This address is displaced by the provided offset.
  DiffSCEV = SE->getAddExpr(DiffSCEV, OffSCEV);

  LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
                    << *AlignSCEV << " and offset " << *OffSCEV
                    << " using diff " << *DiffSCEV << "\n");

  if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
    LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
    return *NewAlignment;
  }

  if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
    // The relative offset to the alignment assumption did not yield a constant,
    // but we should try harder: if we assume that a is 32-byte aligned, then in
    // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
    // 32-byte aligned, but instead alternate between 32 and 16-byte alignment.
    // As a result, the new alignment will not be a constant, but can still
    // be improved over the default (of 4) to 16.
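    // Concretely, assuming 4-byte elements in that example: the byte offset
    // of a[i] from a is the recurrence {0,+,16}, so the start contributes
    // alignment 32 and the per-iteration step contributes alignment 16; the
    // smaller of the two, 16, is what holds on every iteration.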

    const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
    const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);

    LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
                      << *DiffStartSCEV << " and inc " << *DiffIncSCEV
                      << "\n");

    // Now compute the new alignment using the displacement to the value in the
    // first iteration, and also the alignment using the per-iteration delta.
    // If these are the same, then use that answer. Otherwise, use the smaller
    // one, but only if it divides the larger one.
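    // (Both candidates come from getNewAlignmentDiff and are therefore powers
    // of two, and the smaller of two powers of two always divides the larger,
    // so returning the smaller one below is always safe.)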
    MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
    MaybeAlign NewIncAlignment =
        getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);

    LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
                      << "\n");
    LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
                      << "\n");

    if (!NewAlignment || !NewIncAlignment)
      return Align(1);

    const Align NewAlign = *NewAlignment;
    const Align NewIncAlign = *NewIncAlignment;
    if (NewAlign > NewIncAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
                        << DebugStr(NewIncAlign) << "\n");
      return NewIncAlign;
    }
    if (NewIncAlign > NewAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                        << "\n");
      return NewAlign;
    }
    assert(NewIncAlign == NewAlign);
    LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                      << "\n");
    return NewAlign;
  }

  return Align(1);
}
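
// Decode one "align" operand bundle from an assume call. Such a bundle has
// the form ["align"(<pointer>, <alignment>)] or, with an optional
// displacement, ["align"(<pointer>, <alignment>, <offset>)]; a missing
// offset is treated as zero.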
bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
                                                        unsigned Idx,
                                                        Value *&AAPtr,
                                                        const SCEV *&AlignSCEV,
                                                        const SCEV *&OffSCEV) {
  Type *Int64Ty = Type::getInt64Ty(I->getContext());
  OperandBundleUse AlignOB = I->getOperandBundleAt(Idx);
  if (AlignOB.getTagName() != "align")
    return false;
  assert(AlignOB.Inputs.size() >= 2);
  AAPtr = AlignOB.Inputs[0].get();
  // TODO: Consider accumulating the offset to the base.
  AAPtr = AAPtr->stripPointerCastsSameRepresentation();
  AlignSCEV = SE->getSCEV(AlignOB.Inputs[1].get());
  AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
  if (AlignOB.Inputs.size() == 3)
    OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get());
  else
    OffSCEV = SE->getZero(Int64Ty);
  OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
  return true;
}

bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall,
                                                     unsigned Idx) {
  Value *AAPtr;
  const SCEV *AlignSCEV, *OffSCEV;
  if (!extractAlignmentInfo(ACall, Idx, AAPtr, AlignSCEV, OffSCEV))
    return false;

  // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't
  // affect other users.
  if (isa<ConstantData>(AAPtr))
    return false;

  const SCEV *AASCEV = SE->getSCEV(AAPtr);

  // Apply the assumption to all other users of the specified pointer.
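  // The walk below is transitive: after visiting an instruction we also
  // enqueue its users, so that, for example, a load through a GEP of the
  // assumed pointer is reached as well.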
  SmallPtrSet<Instruction *, 32> Visited;
  SmallVector<Instruction *, 16> WorkList;
  for (User *J : AAPtr->users()) {
    if (J == ACall)
      continue;

    if (Instruction *K = dyn_cast<Instruction>(J))
      WorkList.push_back(K);
  }

  while (!WorkList.empty()) {
    Instruction *J = WorkList.pop_back_val();
    if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           LI->getPointerOperand(), SE);
      if (NewAlignment > LI->getAlign()) {
        LI->setAlignment(NewAlignment);
        ++NumLoadAlignChanged;
      }
    } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           SI->getPointerOperand(), SE);
      if (NewAlignment > SI->getAlign()) {
        SI->setAlignment(NewAlignment);
        ++NumStoreAlignChanged;
      }
    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewDestAlignment =
          getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);

      LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
                        << "\n");
      if (NewDestAlignment > *MI->getDestAlign()) {
        MI->setDestAlignment(NewDestAlignment);
        ++NumMemIntAlignChanged;
      }

      // For memory transfers, there is also a source alignment that
      // can be set.
      if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
        Align NewSrcAlignment =
            getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);

        LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
                          << "\n");

        if (NewSrcAlignment > *MTI->getSourceAlign()) {
          MTI->setSourceAlignment(NewSrcAlignment);
          ++NumMemIntAlignChanged;
        }
      }
    }

    // Now that we've updated that use of the pointer, look for other uses of
    // the pointer to update.
    Visited.insert(J);
    for (User *UJ : J->users()) {
      Instruction *K = cast<Instruction>(UJ);
      if (!Visited.count(K))
        WorkList.push_back(K);
    }
  }

  return true;
}

bool AlignmentFromAssumptions::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  return Impl.runImpl(F, AC, SE, DT);
}

bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
                                           ScalarEvolution *SE_,
                                           DominatorTree *DT_) {
  SE = SE_;
  DT = DT_;

  bool Changed = false;
  for (auto &AssumeVH : AC.assumptions())
    if (AssumeVH) {
      CallInst *Call = cast<CallInst>(AssumeVH);
      for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
        Changed |= processAssumption(Call, Idx);
    }

  return Changed;
}

PreservedAnalyses
AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {
  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
  ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, AC, &SE, &DT))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<ScalarEvolutionAnalysis>();
  return PA;
}