//===----------------------- AlignmentFromAssumptions.cpp -----------------===//
//                  Set Load/Store Alignments From Assumptions
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a ScalarEvolution-based transformation to set
// the alignments of loads, stores, and memory intrinsics based on the truth
// expressions of assume intrinsics. The primary motivation is to handle
// complex alignment assumptions that apply to vector loads and stores that
// appear after vectorization and unrolling.
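//
// For example (illustrative), an alignment assumption arrives as an "align"
// operand bundle on an assume call:
//
//   call void @llvm.assume(i1 true) ["align"(ptr %a, i64 32)]
//   call void @llvm.assume(i1 true) ["align"(ptr %a, i64 32, i64 28)]
//
// The first form asserts that %a is 32-byte aligned; the second additionally
// carries a byte offset (extracted below as OffSCEV). The pass uses such
// assumptions to raise the alignment recorded on memory accesses whose
// pointer operands are based on %a.
//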
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "alignment-from-assumptions"
using namespace llvm;

STATISTIC(NumLoadAlignChanged,
          "Number of loads changed by alignment assumptions");
STATISTIC(NumStoreAlignChanged,
          "Number of stores changed by alignment assumptions");
STATISTIC(NumMemIntAlignChanged,
          "Number of memory intrinsics changed by alignment assumptions");

// Given an expression for the (constant) alignment, AlignSCEV, and an
// expression for the displacement between a pointer and the aligned address,
// DiffSCEV, compute the alignment of the displaced pointer if it can be reduced
// to a constant. Using SCEV to compute alignment handles the case where
// DiffSCEV is a recurrence with constant start such that the aligned offset
// is constant. e.g. {16,+,32} % 32 -> 16.
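// For instance, every value of {16,+,32} (i.e. 16 + 32*k) leaves remainder 16
// when divided by 32, so a pointer displaced by that recurrence from a
// 32-byte-aligned address is known to be 16-byte aligned.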
static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
                                      const SCEV *AlignSCEV,
                                      ScalarEvolution *SE) {
  // DiffUnits = Diff % int64_t(Alignment)
  const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);

  LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
                    << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");

  if (const SCEVConstant *ConstDUSCEV =
          dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
    int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();

    // If the displacement is an exact multiple of the alignment, then the
    // displaced pointer has the same alignment as the aligned pointer, so
    // return the alignment value.
    if (!DiffUnits)
      return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();

    // If the displacement is not an exact multiple, but the remainder is a
    // constant, then return this remainder (but only if it is a power of 2).
    uint64_t DiffUnitsAbs = std::abs(DiffUnits);
    if (isPowerOf2_64(DiffUnitsAbs))
      return Align(DiffUnitsAbs);
  }

  return MaybeAlign();
}

// There is an address given by an offset OffSCEV from AASCEV which has an
// alignment AlignSCEV. Use that information, if possible, to compute a new
// alignment for Ptr.
static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
                             const SCEV *OffSCEV, Value *Ptr,
                             ScalarEvolution *SE) {
  const SCEV *PtrSCEV = SE->getSCEV(Ptr);
  const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
  if (isa<SCEVCouldNotCompute>(DiffSCEV))
    return Align(1);

  // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
  // sign-extended OffSCEV to i64, so make sure they agree again.
  DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());

  // What we really want to know is the overall offset to the aligned
  // address. This address is displaced by the provided offset.
  DiffSCEV = SE->getAddExpr(DiffSCEV, OffSCEV);

  LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
                    << *AlignSCEV << " and offset " << *OffSCEV
                    << " using diff " << *DiffSCEV << "\n");

  if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
    LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
    return *NewAlignment;
  }

  if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
    // The relative offset to the alignment assumption did not yield a constant,
    // but we should try harder: if we assume that a is 32-byte aligned, then in
    // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
    // 32-byte aligned, but instead alternate between 32 and 16-byte alignment.
    // As a result, the new alignment will not be a constant, but can still
    // be improved over the default (of 4) to 16.
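    // Working through the example above (assuming 4-byte elements): each
    // iteration advances the pointer by 4 * 4 = 16 bytes, so the displacement
    // from the assumed 32-byte-aligned base is the recurrence {0,+,16}. The
    // start (0) yields an alignment of 32 and the step (16) yields 16; the
    // smaller value, 16, is the alignment that holds on every iteration.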

    const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
    const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);

    LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
                      << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");

    // Now compute the new alignment using the displacement to the value in the
    // first iteration, and also the alignment using the per-iteration delta.
    // If these are the same, then use that answer. Otherwise, use the smaller
    // one, but only if it divides the larger one.
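    // (Both candidates come from getNewAlignmentDiff and are powers of two,
    // so the smaller one necessarily divides the larger one; picking the
    // minimum below is therefore sufficient.)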
    MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
    MaybeAlign NewIncAlignment =
        getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);

    LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
                      << "\n");
    LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
                      << "\n");

    if (!NewAlignment || !NewIncAlignment)
      return Align(1);

    const Align NewAlign = *NewAlignment;
    const Align NewIncAlign = *NewIncAlignment;
    if (NewAlign > NewIncAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
                        << DebugStr(NewIncAlign) << "\n");
      return NewIncAlign;
    }
    if (NewIncAlign > NewAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                        << "\n");
      return NewAlign;
    }
    assert(NewIncAlign == NewAlign);
    LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                      << "\n");
    return NewAlign;
  }

  return Align(1);
}

bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
                                                        unsigned Idx,
                                                        Value *&AAPtr,
                                                        const SCEV *&AlignSCEV,
                                                        const SCEV *&OffSCEV) {
  Type *Int64Ty = Type::getInt64Ty(I->getContext());
  OperandBundleUse AlignOB = I->getOperandBundleAt(Idx);
  if (AlignOB.getTagName() != "align")
    return false;
  assert(AlignOB.Inputs.size() >= 2);
  AAPtr = AlignOB.Inputs[0].get();
  // TODO: Consider accumulating the offset to the base.
  AAPtr = AAPtr->stripPointerCastsSameRepresentation();
  AlignSCEV = SE->getSCEV(AlignOB.Inputs[1].get());
  AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
  if (!isa<SCEVConstant>(AlignSCEV))
    // Added to suppress a crash because consumer doesn't expect non-constant
    // alignments in the assume bundle. TODO: Consider generalizing caller.
    return false;
  if (!cast<SCEVConstant>(AlignSCEV)->getAPInt().isPowerOf2())
    // Only power of two alignments are supported.
    return false;
  if (AlignOB.Inputs.size() == 3)
    OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get());
  else
    OffSCEV = SE->getZero(Int64Ty);
  OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
  return true;
}

bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall,
                                                     unsigned Idx) {
  Value *AAPtr;
  const SCEV *AlignSCEV, *OffSCEV;
  if (!extractAlignmentInfo(ACall, Idx, AAPtr, AlignSCEV, OffSCEV))
    return false;

  // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't
  // affect other users.
  if (isa<ConstantData>(AAPtr))
    return false;

  const SCEV *AASCEV = SE->getSCEV(AAPtr);

  // Apply the assumption to all other users of the specified pointer.
  SmallPtrSet<Instruction *, 32> Visited;
  SmallVector<Instruction *, 16> WorkList;
  for (User *J : AAPtr->users()) {
    if (J == ACall)
      continue;

    if (Instruction *K = dyn_cast<Instruction>(J))
      WorkList.push_back(K);
  }

  while (!WorkList.empty()) {
    Instruction *J = WorkList.pop_back_val();
    if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           LI->getPointerOperand(), SE);
      if (NewAlignment > LI->getAlign()) {
        LI->setAlignment(NewAlignment);
        ++NumLoadAlignChanged;
      }
    } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           SI->getPointerOperand(), SE);
      if (NewAlignment > SI->getAlign()) {
        SI->setAlignment(NewAlignment);
        ++NumStoreAlignChanged;
      }
    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewDestAlignment =
          getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);

      LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
                        << "\n");
      if (NewDestAlignment > *MI->getDestAlign()) {
        MI->setDestAlignment(NewDestAlignment);
        ++NumMemIntAlignChanged;
      }

      // For memory transfers, there is also a source alignment that
      // can be set.
      if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
        Align NewSrcAlignment =
            getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);

        LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
                          << "\n");

        if (NewSrcAlignment > *MTI->getSourceAlign()) {
          MTI->setSourceAlignment(NewSrcAlignment);
          ++NumMemIntAlignChanged;
        }
      }
    }

    // Now that we've updated that use of the pointer, look for other uses of
    // the pointer to update.
    Visited.insert(J);
    if (isa<GetElementPtrInst>(J) || isa<PHINode>(J))
      for (auto &U : J->uses()) {
        if (U->getType()->isPointerTy()) {
          Instruction *K = cast<Instruction>(U.getUser());
          // If the pointer is only the value operand of a store (not the
          // address being stored to), the store does not propagate alignment.
          StoreInst *SI = dyn_cast<StoreInst>(K);
          if (SI && SI->getPointerOperandIndex() != U.getOperandNo())
            continue;
          if (!Visited.count(K))
            WorkList.push_back(K);
        }
      }
  }

  return true;
}

bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
                                           ScalarEvolution *SE_,
                                           DominatorTree *DT_) {
  SE = SE_;
  DT = DT_;

  bool Changed = false;
  for (auto &AssumeVH : AC.assumptions())
    if (AssumeVH) {
      CallInst *Call = cast<CallInst>(AssumeVH);
      for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
        Changed |= processAssumption(Call, Idx);
    }

  return Changed;
}

PreservedAnalyses
AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {

  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
  ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, AC, &SE, &DT))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<ScalarEvolutionAnalysis>();
  return PA;
}