//===----------------------- AlignmentFromAssumptions.cpp -----------------===//
//                  Set Load/Store Alignments From Assumptions
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a ScalarEvolution-based transformation to set the
// alignments of loads, stores, and memory intrinsics based on the truth
// expressions of assume intrinsics. The primary motivation is to handle
// complex alignment assumptions that apply to vector loads and stores that
// appear after vectorization and unrolling.
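//
// For instance, Clang lowers __builtin_assume_aligned(p, 32) into an
// llvm.assume call whose condition has the ptrtoint/and/icmp shape that
// extractAlignmentInfo() below recognizes.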
//
//===----------------------------------------------------------------------===//

#define AA_NAME "alignment-from-assumptions"
#define DEBUG_TYPE AA_NAME
#include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
using namespace llvm;

STATISTIC(NumLoadAlignChanged,
          "Number of loads changed by alignment assumptions");
STATISTIC(NumStoreAlignChanged,
          "Number of stores changed by alignment assumptions");
STATISTIC(NumMemIntAlignChanged,
          "Number of memory intrinsics changed by alignment assumptions");

namespace {
struct AlignmentFromAssumptions : public FunctionPass {
  static char ID; // Pass identification, replacement for typeid
  AlignmentFromAssumptions() : FunctionPass(ID) {
    initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();

    AU.setPreservesCFG();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
  }

  AlignmentFromAssumptionsPass Impl;
};
} // end anonymous namespace

char AlignmentFromAssumptions::ID = 0;
static const char aip_name[] = "Alignment from assumptions";
INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
                      aip_name, false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
                    aip_name, false, false)

FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
  return new AlignmentFromAssumptions();
}

// Given an expression for the (constant) alignment, AlignSCEV, and an
// expression for the displacement between a pointer and the aligned address,
// DiffSCEV, compute the alignment of the displaced pointer if it can be
// reduced to a constant. Using SCEV to compute alignment handles the case
// where DiffSCEV is a recurrence with constant start such that the aligned
// offset is constant. e.g. {16,+,32} % 32 -> 16.
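// (In that example, every value taken by the recurrence {16,+,32} -- 16, 48,
// 80, ... -- is 16 more than a multiple of 32, so the remainder modulo the
// 32-byte alignment is the constant 16.)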
static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV,
                                    const SCEV *AlignSCEV,
                                    ScalarEvolution *SE) {
  // DiffUnits = Diff % int64_t(Alignment)
  const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);

  LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
                    << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");

  if (const SCEVConstant *ConstDUSCEV =
          dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
    int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();

    // If the displacement is an exact multiple of the alignment, then the
    // displaced pointer has the same alignment as the aligned pointer, so
    // return the alignment value.
    if (!DiffUnits)
      return (unsigned)
          cast<SCEVConstant>(AlignSCEV)->getValue()->getSExtValue();

    // If the displacement is not an exact multiple, but the remainder is a
    // constant, then return this remainder (but only if it is a power of 2).
    uint64_t DiffUnitsAbs = std::abs(DiffUnits);
    if (isPowerOf2_64(DiffUnitsAbs))
      return (unsigned) DiffUnitsAbs;
  }

  return 0;
}

// There is an address given by an offset OffSCEV from AASCEV which has an
// alignment AlignSCEV. Use that information, if possible, to compute a new
// alignment for Ptr.
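// For example (illustrative numbers): if the address AASCEV + OffSCEV is
// known to be 32-byte aligned and Ptr lies 16 bytes past that address, the
// alignment computed for Ptr is 16.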
static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
                                const SCEV *OffSCEV, Value *Ptr,
                                ScalarEvolution *SE) {
  const SCEV *PtrSCEV = SE->getSCEV(Ptr);
  const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);

  // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
  // sign-extended OffSCEV to i64, so make sure they agree again.
  DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());

  // What we really want to know is the overall offset to the aligned
  // address. This address is displaced by the provided offset.
  DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV);

  LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
                    << *AlignSCEV << " and offset " << *OffSCEV
                    << " using diff " << *DiffSCEV << "\n");

  unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE);
  LLVM_DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n");

  if (NewAlignment) {
    return NewAlignment;
  } else if (const SCEVAddRecExpr *DiffARSCEV =
                 dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
    // The relative offset to the alignment assumption did not yield a
    // constant, but we should try harder: if we assume that a is 32-byte
    // aligned, then in for (i = 0; i < 1024; i += 4) r += a[i]; not all of
    // the loads from a are 32-byte aligned, but instead alternate between
    // 32 and 16-byte alignment. As a result, the new alignment will not be
    // a constant, but can still be improved over the default (of 4) to 16.
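    // As a concrete reading of that example (illustrative numbers): the
    // displacement recurrence is {0,+,16}, so the first iteration aligns to
    // the full 32 bytes but the per-iteration step only guarantees 16; since
    // 16 divides 32, 16 is a safe alignment for every iteration.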

    const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
    const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);

    LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
                      << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");

    // Now compute the new alignment using the displacement to the value in the
    // first iteration, and also the alignment using the per-iteration delta.
    // If these are the same, then use that answer. Otherwise, use the smaller
    // one, but only if it divides the larger one.
    NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
    unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);

    LLVM_DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n");
    LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n");

    if (!NewAlignment || !NewIncAlignment) {
      return 0;
    } else if (NewAlignment > NewIncAlignment) {
      if (NewAlignment % NewIncAlignment == 0) {
        LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewIncAlignment
                          << "\n");
        return NewIncAlignment;
      }
    } else if (NewIncAlignment > NewAlignment) {
      if (NewIncAlignment % NewAlignment == 0) {
        LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment
                          << "\n");
        return NewAlignment;
      }
    } else if (NewIncAlignment == NewAlignment) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment
                        << "\n");
      return NewAlignment;
    }
  }

  return 0;
}

bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
                                                        Value *&AAPtr,
                                                        const SCEV *&AlignSCEV,
                                                        const SCEV *&OffSCEV) {
  // An alignment assume must be a statement about the least-significant
  // bits of the pointer being zero, possibly with some offset.
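  // In IR the condition typically has this shape (value names are
  // illustrative; this is the pattern Clang emits for
  // __builtin_assume_aligned):
  //   %ptrint = ptrtoint float* %ptr to i64
  //   %maskedptr = and i64 %ptrint, 31
  //   %maskcond = icmp eq i64 %maskedptr, 0
  //   call void @llvm.assume(i1 %maskcond)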
  ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0));
  if (!ICI)
    return false;

  // This must be an expression of the form: x & m == 0.
  if (ICI->getPredicate() != ICmpInst::ICMP_EQ)
    return false;

  // Swap things around so that the RHS is 0.
  Value *CmpLHS = ICI->getOperand(0);
  Value *CmpRHS = ICI->getOperand(1);
  const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS);
  const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS);
  if (CmpLHSSCEV->isZero())
    std::swap(CmpLHS, CmpRHS);
  else if (!CmpRHSSCEV->isZero())
    return false;

  BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS);
  if (!CmpBO || CmpBO->getOpcode() != Instruction::And)
    return false;

  // Swap things around so that the right operand of the and is a constant
  // (the mask); we cannot deal with variable masks.
  Value *AndLHS = CmpBO->getOperand(0);
  Value *AndRHS = CmpBO->getOperand(1);
  const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS);
  const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS);
  if (isa<SCEVConstant>(AndLHSSCEV)) {
    std::swap(AndLHS, AndRHS);
    std::swap(AndLHSSCEV, AndRHSSCEV);
  }

  const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV);
  if (!MaskSCEV)
    return false;

  // The mask must have some trailing ones (otherwise the condition is
  // trivial and tells us nothing about the alignment of the left operand).
  unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes();
  if (!TrailingOnes)
    return false;
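  // For example (illustrative): a mask of 31 (0b11111) has five trailing
  // ones, so a zero result proves the low five bits of the pointer value are
  // clear, i.e. at least 1 << 5 == 32-byte alignment.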

  // Cap the alignment at the maximum with which LLVM can deal (and make sure
  // we don't overflow the shift).
  uint64_t Alignment;
  TrailingOnes = std::min(TrailingOnes,
                          unsigned(sizeof(unsigned) * CHAR_BIT - 1));
  Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment);

  Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext());
  AlignSCEV = SE->getConstant(Int64Ty, Alignment);

  // The LHS might be a ptrtoint instruction, or it might be the pointer
  // with an offset.
  AAPtr = nullptr;
  OffSCEV = nullptr;
  if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) {
    AAPtr = PToI->getPointerOperand();
    OffSCEV = SE->getZero(Int64Ty);
  } else if (const SCEVAddExpr *AndLHSAddSCEV =
                 dyn_cast<SCEVAddExpr>(AndLHSSCEV)) {
    // Try to find the ptrtoint; subtract it and the rest is the offset.
    for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(),
                                  JE = AndLHSAddSCEV->op_end();
         J != JE; ++J)
      if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J))
        if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) {
          AAPtr = PToI->getPointerOperand();
          OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J);
          break;
        }
  }

  if (!AAPtr)
    return false;

  // Sign extend the offset to 64 bits (so that it is like all of the other
  // expressions).
  unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits();
  if (OffSCEVBits < 64)
    OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty);
  else if (OffSCEVBits > 64)
    return false;

  AAPtr = AAPtr->stripPointerCasts();
  return true;
}

bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
  Value *AAPtr;
  const SCEV *AlignSCEV, *OffSCEV;
  if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV))
    return false;

  // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't
  // affect other users.
  if (isa<ConstantData>(AAPtr))
    return false;

  const SCEV *AASCEV = SE->getSCEV(AAPtr);

  // Apply the assumption to all other users of the specified pointer.
  SmallPtrSet<Instruction *, 32> Visited;
  SmallVector<Instruction *, 16> WorkList;
  for (User *J : AAPtr->users()) {
    if (J == ACall)
      continue;

    if (Instruction *K = dyn_cast<Instruction>(J))
      if (isValidAssumeForContext(ACall, K, DT))
        WorkList.push_back(K);
  }

  while (!WorkList.empty()) {
    Instruction *J = WorkList.pop_back_val();

    if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
      unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                              LI->getPointerOperand(), SE);

      if (NewAlignment > LI->getAlignment()) {
        LI->setAlignment(MaybeAlign(NewAlignment));
        ++NumLoadAlignChanged;
      }
    } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
      unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                              SI->getPointerOperand(), SE);

      if (NewAlignment > SI->getAlignment()) {
        SI->setAlignment(MaybeAlign(NewAlignment));
        ++NumStoreAlignChanged;
      }
    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
      unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                                  MI->getDest(), SE);

      LLVM_DEBUG(dbgs() << "\tmem inst: " << NewDestAlignment << "\n";);
      if (NewDestAlignment > MI->getDestAlignment()) {
        MI->setDestAlignment(NewDestAlignment);
        ++NumMemIntAlignChanged;
      }

      // For memory transfers, there is also a source alignment that
      // can be set.
      if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
        unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                                   MTI->getSource(), SE);

        LLVM_DEBUG(dbgs() << "\tmem trans: " << NewSrcAlignment << "\n";);

        if (NewSrcAlignment > MTI->getSourceAlignment()) {
          MTI->setSourceAlignment(NewSrcAlignment);
          ++NumMemIntAlignChanged;
        }
      }
    }

    // Now that we've updated that use of the pointer, look for other uses of
    // the pointer to update.
    Visited.insert(J);
    for (User *UJ : J->users()) {
      Instruction *K = cast<Instruction>(UJ);
      if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT))
        WorkList.push_back(K);
    }
  }

  return true;
}

bool AlignmentFromAssumptions::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  return Impl.runImpl(F, AC, SE, DT);
}

bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
                                           ScalarEvolution *SE_,
                                           DominatorTree *DT_) {
  SE = SE_;
  DT = DT_;

  bool Changed = false;
  for (auto &AssumeVH : AC.assumptions())
    if (AssumeVH)
      Changed |= processAssumption(cast<CallInst>(AssumeVH));

  return Changed;
}

PreservedAnalyses
AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {
  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
  ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, AC, &SE, &DT))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<AAManager>();
  PA.preserve<ScalarEvolutionAnalysis>();
  PA.preserve<GlobalsAA>();
  return PA;
}