//===----------------------- AlignmentFromAssumptions.cpp -----------------===//
//                  Set Load/Store Alignments From Assumptions
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a ScalarEvolution-based transformation to set
// the alignments of loads, stores, and memory intrinsics based on the truth
// expressions of assume intrinsics. The primary motivation is to handle
// complex alignment assumptions that apply to vector loads and stores that
// appear after vectorization and unrolling.
//
//===----------------------------------------------------------------------===//
#define AA_NAME "alignment-from-assumptions"
#define DEBUG_TYPE AA_NAME
#include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
using namespace llvm;
STATISTIC(NumLoadAlignChanged,
          "Number of loads changed by alignment assumptions");
STATISTIC(NumStoreAlignChanged,
          "Number of stores changed by alignment assumptions");
STATISTIC(NumMemIntAlignChanged,
          "Number of memory intrinsics changed by alignment assumptions");
namespace {
struct AlignmentFromAssumptions : public FunctionPass {
  static char ID; // Pass identification, replacement for typeid
  AlignmentFromAssumptions() : FunctionPass(ID) {
    initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();

    AU.setPreservesCFG();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
  }

  AlignmentFromAssumptionsPass Impl;
};
} // end anonymous namespace
char AlignmentFromAssumptions::ID = 0;
static const char aip_name[] = "Alignment from assumptions";
INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
                      aip_name, false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
                    aip_name, false, false)

FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
  return new AlignmentFromAssumptions();
}
// Given an expression for the (constant) alignment, AlignSCEV, and an
// expression for the displacement between a pointer and the aligned address,
// DiffSCEV, compute the alignment of the displaced pointer if it can be
// reduced to a constant. Using SCEV to compute alignment handles the case
// where DiffSCEV is a recurrence with constant start such that the aligned
// offset is constant. e.g. {16,+,32} % 32 -> 16.
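//
// As a worked example of the computation below (illustrative values, not from
// any particular test case): with DiffSCEV = 48 and AlignSCEV = 32,
//   DiffAlignDiv = 48 /u 32 = 1
//   DiffAlign    = 1 * 32   = 32
//   DiffUnits    = 32 - 48  = -16
// and |-16| = 16 is a power of 2, so the displaced pointer is 16-byte
// aligned. For the recurrence {16,+,32} against AlignSCEV = 32, SCEV can fold
// the same computation to the constant -16, yielding 16 as well.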
static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV,
                                    const SCEV *AlignSCEV,
                                    ScalarEvolution *SE) {
  // DiffUnits = Diff % int64_t(Alignment)
  const SCEV *DiffAlignDiv = SE->getUDivExpr(DiffSCEV, AlignSCEV);
  const SCEV *DiffAlign = SE->getMulExpr(DiffAlignDiv, AlignSCEV);
  const SCEV *DiffUnitsSCEV = SE->getMinusSCEV(DiffAlign, DiffSCEV);

  LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
                    << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");

  if (const SCEVConstant *ConstDUSCEV =
      dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
    int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();

    // If the displacement is an exact multiple of the alignment, then the
    // displaced pointer has the same alignment as the aligned pointer, so
    // return the alignment value.
    if (!DiffUnits)
      return (unsigned)
        cast<SCEVConstant>(AlignSCEV)->getValue()->getSExtValue();

    // If the displacement is not an exact multiple, but the remainder is a
    // constant, then return this remainder (but only if it is a power of 2).
    uint64_t DiffUnitsAbs = std::abs(DiffUnits);
    if (isPowerOf2_64(DiffUnitsAbs))
      return (unsigned) DiffUnitsAbs;
  }

  return 0;
}
// There is an address given by an offset OffSCEV from AASCEV which has an
// alignment AlignSCEV. Use that information, if possible, to compute a new
// alignment for Ptr.
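//
// In other words, the assumption tells us that (AASCEV + OffSCEV) is
// AlignSCEV-aligned, so the displacement that matters for Ptr is
// Diff = (PtrSCEV - AASCEV) - OffSCEV, which is what is computed below before
// being handed to getNewAlignmentDiff.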
static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
                                const SCEV *OffSCEV, Value *Ptr,
                                ScalarEvolution *SE) {
  const SCEV *PtrSCEV = SE->getSCEV(Ptr);
  const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);

  // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
  // sign-extended OffSCEV to i64, so make sure they agree again.
  DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());

  // What we really want to know is the overall offset to the aligned
  // address. This address is displaced by the provided offset.
  DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV);

  LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
                    << *AlignSCEV << " and offset " << *OffSCEV
                    << " using diff " << *DiffSCEV << "\n");

  unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE);
  LLVM_DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n");

  if (NewAlignment) {
    return NewAlignment;
  } else if (const SCEVAddRecExpr *DiffARSCEV =
             dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
    // The relative offset to the alignment assumption did not yield a
    // constant, but we should try harder: if we assume that a is 32-byte
    // aligned, then in for (i = 0; i < 1024; i += 4) r += a[i]; not all of
    // the loads from a are 32-byte aligned, but instead alternate between
    // 32 and 16-byte alignment. As a result, the new alignment will not be a
    // constant, but can still be improved over the default (of 4) to 16.
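    //
    // Concretely (assuming 4-byte elements in that example), the loads
    // advance by 16 bytes per iteration, so DiffSCEV is the recurrence
    // {0,+,16}: the start yields alignment 32 and the step yields alignment
    // 16, and since 16 divides 32 the logic below settles on 16.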
    const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
    const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);

    LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
                      << *DiffStartSCEV << " and inc " << *DiffIncSCEV
                      << "\n");

    // Now compute the new alignment using the displacement to the value in
    // the first iteration, and also the alignment using the per-iteration
    // delta. If these are the same, then use that answer. Otherwise, use the
    // smaller one, but only if it divides the larger one.
    NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
    unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);

    LLVM_DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n");
    LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n");

    if (!NewAlignment || !NewIncAlignment) {
      return 0;
    } else if (NewAlignment > NewIncAlignment) {
      if (NewAlignment % NewIncAlignment == 0) {
        LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewIncAlignment
                          << "\n");
        return NewIncAlignment;
      }
    } else if (NewIncAlignment > NewAlignment) {
      if (NewIncAlignment % NewAlignment == 0) {
        LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment
                          << "\n");
        return NewAlignment;
      }
    } else if (NewIncAlignment == NewAlignment) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment
                        << "\n");
      return NewAlignment;
    }
  }

  return 0;
}
bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
                                                        Value *&AAPtr,
                                                        const SCEV *&AlignSCEV,
                                                        const SCEV *&OffSCEV) {
  // An alignment assume must be a statement about the least-significant
  // bits of the pointer being zero, possibly with some offset.
  ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0));
  if (!ICI)
    return false;
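
  // For example, a source-level __builtin_assume_aligned(a, 32) typically
  // reaches this pass as a pattern of the following shape (illustrative IR,
  // with %a standing in for the pointer in question):
  //   %ptrint    = ptrtoint float* %a to i64
  //   %maskedptr = and i64 %ptrint, 31
  //   %maskcond  = icmp eq i64 %maskedptr, 0
  //   call void @llvm.assume(i1 %maskcond)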
  // This must be an expression of the form: x & m == 0.
  if (ICI->getPredicate() != ICmpInst::ICMP_EQ)
    return false;
  // Swap things around so that the RHS is 0.
  Value *CmpLHS = ICI->getOperand(0);
  Value *CmpRHS = ICI->getOperand(1);
  const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS);
  const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS);
  if (CmpLHSSCEV->isZero())
    std::swap(CmpLHS, CmpRHS);
  else if (!CmpRHSSCEV->isZero())
    return false;
  BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS);
  if (!CmpBO || CmpBO->getOpcode() != Instruction::And)
    return false;
  // Swap things around so that the right operand of the and is a constant
  // (the mask); we cannot deal with variable masks.
  Value *AndLHS = CmpBO->getOperand(0);
  Value *AndRHS = CmpBO->getOperand(1);
  const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS);
  const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS);
  if (isa<SCEVConstant>(AndLHSSCEV)) {
    std::swap(AndLHS, AndRHS);
    std::swap(AndLHSSCEV, AndRHSSCEV);
  }
  const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV);
  if (!MaskSCEV)
    return false;
  // The mask must have some trailing ones (otherwise the condition is
  // trivial and tells us nothing about the alignment of the left operand).
  unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes();
  if (!TrailingOnes)
    return false;
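
  // For instance, a mask of 31 (0b11111) has five trailing ones, which the
  // computation below turns into an alignment of 1 << 5 = 32 bytes.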
  // Cap the alignment at the maximum with which LLVM can deal (and make sure
  // we don't overflow the shift).
  uint64_t Alignment;
  TrailingOnes = std::min(TrailingOnes,
                          unsigned(sizeof(unsigned) * CHAR_BIT - 1));
  Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment);
  Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext());
  AlignSCEV = SE->getConstant(Int64Ty, Alignment);
  // The LHS might be a ptrtoint instruction, or it might be the pointer
  // with an offset.
  AAPtr = nullptr;
  OffSCEV = nullptr;
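  // For example, if the condition tested (ptrtoint(%p) + 4) & 31 == 0, then
  // %p + 4 is 32-byte aligned, so we want AAPtr = %p and OffSCEV = 4.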
  if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) {
    AAPtr = PToI->getPointerOperand();
    OffSCEV = SE->getZero(Int64Ty);
  } else if (const SCEVAddExpr *AndLHSAddSCEV =
             dyn_cast<SCEVAddExpr>(AndLHSSCEV)) {
    // Try to find the ptrtoint; subtract it and the rest is the offset.
    for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(),
         JE = AndLHSAddSCEV->op_end(); J != JE; ++J)
      if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J))
        if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) {
          AAPtr = PToI->getPointerOperand();
          OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J);
          break;
        }
  }

  if (!AAPtr)
    return false;
  // Sign extend the offset to 64 bits (so that it is like all of the other
  // expressions).
  unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits();
  if (OffSCEVBits < 64)
    OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty);
  else if (OffSCEVBits > 64)
    return false;

  AAPtr = AAPtr->stripPointerCasts();
  return true;
}
bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
  Value *AAPtr;
  const SCEV *AlignSCEV, *OffSCEV;
  if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV))
    return false;
  // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't
  // affect other users.
  if (isa<ConstantData>(AAPtr))
    return false;

  const SCEV *AASCEV = SE->getSCEV(AAPtr);

  // Apply the assumption to all other users of the specified pointer.
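  // The walk below is transitive: after an instruction is handled, its own
  // users are enqueued too, so uses reached through, e.g., a GEP or bitcast
  // of the pointer are visited as well.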
  SmallPtrSet<Instruction *, 32> Visited;
  SmallVector<Instruction *, 16> WorkList;
  for (User *J : AAPtr->users()) {
    if (J == ACall)
      continue;

    if (Instruction *K = dyn_cast<Instruction>(J))
      if (isValidAssumeForContext(ACall, K, DT))
        WorkList.push_back(K);
  }
  while (!WorkList.empty()) {
    Instruction *J = WorkList.pop_back_val();
    if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
      unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                              LI->getPointerOperand(), SE);

      if (NewAlignment > LI->getAlignment()) {
        LI->setAlignment(NewAlignment);
        ++NumLoadAlignChanged;
      }
    } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
      unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                              SI->getPointerOperand(), SE);

      if (NewAlignment > SI->getAlignment()) {
        SI->setAlignment(NewAlignment);
        ++NumStoreAlignChanged;
      }
    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
      unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                                  MI->getDest(), SE);

      LLVM_DEBUG(dbgs() << "\tmem inst: " << NewDestAlignment << "\n";);
      if (NewDestAlignment > MI->getDestAlignment()) {
        MI->setDestAlignment(NewDestAlignment);
        ++NumMemIntAlignChanged;
      }
      // For memory transfers, there is also a source alignment that
      // can be set.
      if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
        unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                                   MTI->getSource(), SE);

        LLVM_DEBUG(dbgs() << "\tmem trans: " << NewSrcAlignment << "\n";);

        if (NewSrcAlignment > MTI->getSourceAlignment()) {
          MTI->setSourceAlignment(NewSrcAlignment);
          ++NumMemIntAlignChanged;
        }
      }
    }
    // Now that we've updated that use of the pointer, look for other uses of
    // the pointer to update.
    Visited.insert(J);
    for (User *UJ : J->users()) {
      Instruction *K = cast<Instruction>(UJ);
      if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT))
        WorkList.push_back(K);
    }
  }

  return true;
}
bool AlignmentFromAssumptions::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  return Impl.runImpl(F, AC, SE, DT);
}
bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
                                           ScalarEvolution *SE_,
                                           DominatorTree *DT_) {
  SE = SE_;
  DT = DT_;

  bool Changed = false;
  for (auto &AssumeVH : AC.assumptions())
    if (AssumeVH)
      Changed |= processAssumption(cast<CallInst>(AssumeVH));

  return Changed;
}
PreservedAnalyses
AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {
  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
  ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, AC, &SE, &DT))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<AAManager>();
  PA.preserve<ScalarEvolutionAnalysis>();
  PA.preserve<GlobalsAA>();
  return PA;
}