1 //===- NaryReassociate.cpp - Reassociate n-ary expressions ----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This pass reassociates n-ary add expressions and eliminates the redundancy
10 // exposed by the reassociation.
12 // A motivating example:
14 // void foo(int a, int b) {
19 // An ideal compiler should reassociate (a + 2) + b to (a + b) + 2 and simplify
26 // However, the Reassociate pass is unable to do that because it processes each
27 // instruction individually and believes (a + 2) + b is the best form according
28 // to its rank system.
30 // To address this limitation, NaryReassociate reassociates an expression in a
31 // form that reuses existing instructions. As a result, NaryReassociate can
32 // reassociate (a + 2) + b in the example to (a + b) + 2 because it detects that
33 // (a + b) is computed before.
35 // NaryReassociate works as follows. For every instruction in the form of (a +
36 // b) + c, it checks whether a + c or b + c is already computed by a dominating
37 // instruction. If so, it then reassociates (a + b) + c into (a + c) + b or (b +
38 // c) + a and removes the redundancy accordingly. To efficiently look up whether
39 // an expression is computed before, we store each instruction seen and its SCEV
40 // into an SCEV-to-instruction map.
42 // Although the algorithm pattern-matches only ternary additions, it
43 // automatically handles many >3-ary expressions by walking through the function
44 // in the depth-first order. For example, given
49 // NaryReassociate first rewrites (a + b) + c to (a + c) + b, and then rewrites
50 // ((a + c) + b) + d into ((a + c) + d) + b.
52 // Finally, the above dominator-based algorithm may need to be run multiple
53 // iterations before emitting optimal code. One source of this need is that we
54 // only split an operand when it is used only once. The above algorithm can
55 // eliminate an instruction and decrease the usage count of its operands. As a
56 // result, an instruction that previously had multiple uses may become a
57 // single-use instruction and thus eligible for split consideration. For
66 // In the first iteration, we cannot reassociate abc to ac+b because ab is used
67 // twice. However, we can reassociate ab2c to abc+b in the first iteration. As a
68 // result, ab2 becomes dead and ab will be used only once in the second
71 // Limitations and TODO items:
// 1) We only consider n-ary adds and muls for now. This should be extended
76 //===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/NaryReassociate.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
111 using namespace llvm
;
112 using namespace PatternMatch
;
114 #define DEBUG_TYPE "nary-reassociate"
118 class NaryReassociateLegacyPass
: public FunctionPass
{
122 NaryReassociateLegacyPass() : FunctionPass(ID
) {
123 initializeNaryReassociateLegacyPassPass(*PassRegistry::getPassRegistry());
126 bool doInitialization(Module
&M
) override
{
130 bool runOnFunction(Function
&F
) override
;
132 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
133 AU
.addPreserved
<DominatorTreeWrapperPass
>();
134 AU
.addPreserved
<ScalarEvolutionWrapperPass
>();
135 AU
.addPreserved
<TargetLibraryInfoWrapperPass
>();
136 AU
.addRequired
<AssumptionCacheTracker
>();
137 AU
.addRequired
<DominatorTreeWrapperPass
>();
138 AU
.addRequired
<ScalarEvolutionWrapperPass
>();
139 AU
.addRequired
<TargetLibraryInfoWrapperPass
>();
140 AU
.addRequired
<TargetTransformInfoWrapperPass
>();
141 AU
.setPreservesCFG();
145 NaryReassociatePass Impl
;
148 } // end anonymous namespace
char NaryReassociateLegacyPass::ID = 0;

// Register the legacy pass and its analysis dependencies with the legacy
// PassRegistry under the command-line name "nary-reassociate".
INITIALIZE_PASS_BEGIN(NaryReassociateLegacyPass, "nary-reassociate",
                      "Nary reassociation", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(NaryReassociateLegacyPass, "nary-reassociate",
                    "Nary reassociation", false, false)
162 FunctionPass
*llvm::createNaryReassociatePass() {
163 return new NaryReassociateLegacyPass();
166 bool NaryReassociateLegacyPass::runOnFunction(Function
&F
) {
170 auto *AC
= &getAnalysis
<AssumptionCacheTracker
>().getAssumptionCache(F
);
171 auto *DT
= &getAnalysis
<DominatorTreeWrapperPass
>().getDomTree();
172 auto *SE
= &getAnalysis
<ScalarEvolutionWrapperPass
>().getSE();
173 auto *TLI
= &getAnalysis
<TargetLibraryInfoWrapperPass
>().getTLI(F
);
174 auto *TTI
= &getAnalysis
<TargetTransformInfoWrapperPass
>().getTTI(F
);
176 return Impl
.runImpl(F
, AC
, DT
, SE
, TLI
, TTI
);
179 PreservedAnalyses
NaryReassociatePass::run(Function
&F
,
180 FunctionAnalysisManager
&AM
) {
181 auto *AC
= &AM
.getResult
<AssumptionAnalysis
>(F
);
182 auto *DT
= &AM
.getResult
<DominatorTreeAnalysis
>(F
);
183 auto *SE
= &AM
.getResult
<ScalarEvolutionAnalysis
>(F
);
184 auto *TLI
= &AM
.getResult
<TargetLibraryAnalysis
>(F
);
185 auto *TTI
= &AM
.getResult
<TargetIRAnalysis
>(F
);
187 if (!runImpl(F
, AC
, DT
, SE
, TLI
, TTI
))
188 return PreservedAnalyses::all();
190 PreservedAnalyses PA
;
191 PA
.preserveSet
<CFGAnalyses
>();
192 PA
.preserve
<ScalarEvolutionAnalysis
>();
196 bool NaryReassociatePass::runImpl(Function
&F
, AssumptionCache
*AC_
,
197 DominatorTree
*DT_
, ScalarEvolution
*SE_
,
198 TargetLibraryInfo
*TLI_
,
199 TargetTransformInfo
*TTI_
) {
205 DL
= &F
.getParent()->getDataLayout();
207 bool Changed
= false, ChangedInThisIteration
;
209 ChangedInThisIteration
= doOneIteration(F
);
210 Changed
|= ChangedInThisIteration
;
211 } while (ChangedInThisIteration
);
215 // Whitelist the instruction types NaryReassociate handles for now.
216 static bool isPotentiallyNaryReassociable(Instruction
*I
) {
217 switch (I
->getOpcode()) {
218 case Instruction::Add
:
219 case Instruction::GetElementPtr
:
220 case Instruction::Mul
:
227 bool NaryReassociatePass::doOneIteration(Function
&F
) {
228 bool Changed
= false;
230 // Process the basic blocks in a depth first traversal of the dominator
231 // tree. This order ensures that all bases of a candidate are in Candidates
232 // when we process it.
233 for (const auto Node
: depth_first(DT
)) {
234 BasicBlock
*BB
= Node
->getBlock();
235 for (auto I
= BB
->begin(); I
!= BB
->end(); ++I
) {
236 if (SE
->isSCEVable(I
->getType()) && isPotentiallyNaryReassociable(&*I
)) {
237 const SCEV
*OldSCEV
= SE
->getSCEV(&*I
);
238 if (Instruction
*NewI
= tryReassociate(&*I
)) {
240 SE
->forgetValue(&*I
);
241 I
->replaceAllUsesWith(NewI
);
242 WeakVH NewIExist
= NewI
;
243 // If SeenExprs/NewIExist contains I's WeakTrackingVH/WeakVH, that
244 // entry will be replaced with nullptr if deleted.
245 RecursivelyDeleteTriviallyDeadInstructions(&*I
, TLI
);
247 // Rare occation where the new instruction (NewI) have been removed,
248 // probably due to parts of the input code was dead from the
249 // beginning, reset the iterator and start over from the beginning
253 I
= NewI
->getIterator();
255 // Add the rewritten instruction to SeenExprs; the original instruction
257 const SCEV
*NewSCEV
= SE
->getSCEV(&*I
);
258 SeenExprs
[NewSCEV
].push_back(WeakTrackingVH(&*I
));
259 // Ideally, NewSCEV should equal OldSCEV because tryReassociate(I)
260 // is equivalent to I. However, ScalarEvolution::getSCEV may
261 // weaken nsw causing NewSCEV not to equal OldSCEV. For example, suppose
263 // I = &a[sext(i +nsw j)] // assuming sizeof(a[0]) = 4
265 // NewI = &a[sext(i)] + sext(j).
267 // ScalarEvolution computes
268 // getSCEV(I) = a + 4 * sext(i + j)
269 // getSCEV(newI) = a + 4 * sext(i) + 4 * sext(j)
270 // which are different SCEVs.
272 // To alleviate this issue of ScalarEvolution not always capturing
273 // equivalence, we add I to SeenExprs[OldSCEV] as well so that we can
274 // map both SCEV before and after tryReassociate(I) to I.
276 // This improvement is exercised in @reassociate_gep_nsw in nary-gep.ll.
277 if (NewSCEV
!= OldSCEV
)
278 SeenExprs
[OldSCEV
].push_back(WeakTrackingVH(&*I
));
285 Instruction
*NaryReassociatePass::tryReassociate(Instruction
*I
) {
286 switch (I
->getOpcode()) {
287 case Instruction::Add
:
288 case Instruction::Mul
:
289 return tryReassociateBinaryOp(cast
<BinaryOperator
>(I
));
290 case Instruction::GetElementPtr
:
291 return tryReassociateGEP(cast
<GetElementPtrInst
>(I
));
293 llvm_unreachable("should be filtered out by isPotentiallyNaryReassociable");
297 static bool isGEPFoldable(GetElementPtrInst
*GEP
,
298 const TargetTransformInfo
*TTI
) {
299 SmallVector
<const Value
*, 4> Indices
;
300 for (auto I
= GEP
->idx_begin(); I
!= GEP
->idx_end(); ++I
)
301 Indices
.push_back(*I
);
302 return TTI
->getGEPCost(GEP
->getSourceElementType(), GEP
->getPointerOperand(),
303 Indices
) == TargetTransformInfo::TCC_Free
;
306 Instruction
*NaryReassociatePass::tryReassociateGEP(GetElementPtrInst
*GEP
) {
307 // Not worth reassociating GEP if it is foldable.
308 if (isGEPFoldable(GEP
, TTI
))
311 gep_type_iterator GTI
= gep_type_begin(*GEP
);
312 for (unsigned I
= 1, E
= GEP
->getNumOperands(); I
!= E
; ++I
, ++GTI
) {
313 if (GTI
.isSequential()) {
314 if (auto *NewGEP
= tryReassociateGEPAtIndex(GEP
, I
- 1,
315 GTI
.getIndexedType())) {
323 bool NaryReassociatePass::requiresSignExtension(Value
*Index
,
324 GetElementPtrInst
*GEP
) {
325 unsigned PointerSizeInBits
=
326 DL
->getPointerSizeInBits(GEP
->getType()->getPointerAddressSpace());
327 return cast
<IntegerType
>(Index
->getType())->getBitWidth() < PointerSizeInBits
;
331 NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst
*GEP
,
332 unsigned I
, Type
*IndexedType
) {
333 Value
*IndexToSplit
= GEP
->getOperand(I
+ 1);
334 if (SExtInst
*SExt
= dyn_cast
<SExtInst
>(IndexToSplit
)) {
335 IndexToSplit
= SExt
->getOperand(0);
336 } else if (ZExtInst
*ZExt
= dyn_cast
<ZExtInst
>(IndexToSplit
)) {
337 // zext can be treated as sext if the source is non-negative.
338 if (isKnownNonNegative(ZExt
->getOperand(0), *DL
, 0, AC
, GEP
, DT
))
339 IndexToSplit
= ZExt
->getOperand(0);
342 if (AddOperator
*AO
= dyn_cast
<AddOperator
>(IndexToSplit
)) {
343 // If the I-th index needs sext and the underlying add is not equipped with
344 // nsw, we cannot split the add because
345 // sext(LHS + RHS) != sext(LHS) + sext(RHS).
346 if (requiresSignExtension(IndexToSplit
, GEP
) &&
347 computeOverflowForSignedAdd(AO
, *DL
, AC
, GEP
, DT
) !=
348 OverflowResult::NeverOverflows
)
351 Value
*LHS
= AO
->getOperand(0), *RHS
= AO
->getOperand(1);
352 // IndexToSplit = LHS + RHS.
353 if (auto *NewGEP
= tryReassociateGEPAtIndex(GEP
, I
, LHS
, RHS
, IndexedType
))
355 // Symmetrically, try IndexToSplit = RHS + LHS.
358 tryReassociateGEPAtIndex(GEP
, I
, RHS
, LHS
, IndexedType
))
366 NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst
*GEP
,
367 unsigned I
, Value
*LHS
,
368 Value
*RHS
, Type
*IndexedType
) {
369 // Look for GEP's closest dominator that has the same SCEV as GEP except that
370 // the I-th index is replaced with LHS.
371 SmallVector
<const SCEV
*, 4> IndexExprs
;
372 for (auto Index
= GEP
->idx_begin(); Index
!= GEP
->idx_end(); ++Index
)
373 IndexExprs
.push_back(SE
->getSCEV(*Index
));
374 // Replace the I-th index with LHS.
375 IndexExprs
[I
] = SE
->getSCEV(LHS
);
376 if (isKnownNonNegative(LHS
, *DL
, 0, AC
, GEP
, DT
) &&
377 DL
->getTypeSizeInBits(LHS
->getType()) <
378 DL
->getTypeSizeInBits(GEP
->getOperand(I
)->getType())) {
379 // Zero-extend LHS if it is non-negative. InstCombine canonicalizes sext to
380 // zext if the source operand is proved non-negative. We should do that
381 // consistently so that CandidateExpr more likely appears before. See
382 // @reassociate_gep_assume for an example of this canonicalization.
384 SE
->getZeroExtendExpr(IndexExprs
[I
], GEP
->getOperand(I
)->getType());
386 const SCEV
*CandidateExpr
= SE
->getGEPExpr(cast
<GEPOperator
>(GEP
),
389 Value
*Candidate
= findClosestMatchingDominator(CandidateExpr
, GEP
);
390 if (Candidate
== nullptr)
393 IRBuilder
<> Builder(GEP
);
394 // Candidate does not necessarily have the same pointer type as GEP. Use
395 // bitcast or pointer cast to make sure they have the same type, so that the
396 // later RAUW doesn't complain.
397 Candidate
= Builder
.CreateBitOrPointerCast(Candidate
, GEP
->getType());
398 assert(Candidate
->getType() == GEP
->getType());
400 // NewGEP = (char *)Candidate + RHS * sizeof(IndexedType)
401 uint64_t IndexedSize
= DL
->getTypeAllocSize(IndexedType
);
402 Type
*ElementType
= GEP
->getResultElementType();
403 uint64_t ElementSize
= DL
->getTypeAllocSize(ElementType
);
404 // Another less rare case: because I is not necessarily the last index of the
405 // GEP, the size of the type at the I-th index (IndexedSize) is not
406 // necessarily divisible by ElementSize. For example,
415 // sizeof(S) = 100 is indivisible by sizeof(int64) = 8.
417 // TODO: bail out on this case for now. We could emit uglygep.
418 if (IndexedSize
% ElementSize
!= 0)
421 // NewGEP = &Candidate[RHS * (sizeof(IndexedType) / sizeof(Candidate[0])));
422 Type
*IntPtrTy
= DL
->getIntPtrType(GEP
->getType());
423 if (RHS
->getType() != IntPtrTy
)
424 RHS
= Builder
.CreateSExtOrTrunc(RHS
, IntPtrTy
);
425 if (IndexedSize
!= ElementSize
) {
426 RHS
= Builder
.CreateMul(
427 RHS
, ConstantInt::get(IntPtrTy
, IndexedSize
/ ElementSize
));
429 GetElementPtrInst
*NewGEP
= cast
<GetElementPtrInst
>(
430 Builder
.CreateGEP(GEP
->getResultElementType(), Candidate
, RHS
));
431 NewGEP
->setIsInBounds(GEP
->isInBounds());
432 NewGEP
->takeName(GEP
);
436 Instruction
*NaryReassociatePass::tryReassociateBinaryOp(BinaryOperator
*I
) {
437 Value
*LHS
= I
->getOperand(0), *RHS
= I
->getOperand(1);
438 // There is no need to reassociate 0.
439 if (SE
->getSCEV(I
)->isZero())
441 if (auto *NewI
= tryReassociateBinaryOp(LHS
, RHS
, I
))
443 if (auto *NewI
= tryReassociateBinaryOp(RHS
, LHS
, I
))
448 Instruction
*NaryReassociatePass::tryReassociateBinaryOp(Value
*LHS
, Value
*RHS
,
450 Value
*A
= nullptr, *B
= nullptr;
451 // To be conservative, we reassociate I only when it is the only user of (A op
453 if (LHS
->hasOneUse() && matchTernaryOp(I
, LHS
, A
, B
)) {
454 // I = (A op B) op RHS
455 // = (A op RHS) op B or (B op RHS) op A
456 const SCEV
*AExpr
= SE
->getSCEV(A
), *BExpr
= SE
->getSCEV(B
);
457 const SCEV
*RHSExpr
= SE
->getSCEV(RHS
);
458 if (BExpr
!= RHSExpr
) {
460 tryReassociatedBinaryOp(getBinarySCEV(I
, AExpr
, RHSExpr
), B
, I
))
463 if (AExpr
!= RHSExpr
) {
465 tryReassociatedBinaryOp(getBinarySCEV(I
, BExpr
, RHSExpr
), A
, I
))
472 Instruction
*NaryReassociatePass::tryReassociatedBinaryOp(const SCEV
*LHSExpr
,
475 // Look for the closest dominator LHS of I that computes LHSExpr, and replace
476 // I with LHS op RHS.
477 auto *LHS
= findClosestMatchingDominator(LHSExpr
, I
);
481 Instruction
*NewI
= nullptr;
482 switch (I
->getOpcode()) {
483 case Instruction::Add
:
484 NewI
= BinaryOperator::CreateAdd(LHS
, RHS
, "", I
);
486 case Instruction::Mul
:
487 NewI
= BinaryOperator::CreateMul(LHS
, RHS
, "", I
);
490 llvm_unreachable("Unexpected instruction.");
496 bool NaryReassociatePass::matchTernaryOp(BinaryOperator
*I
, Value
*V
,
497 Value
*&Op1
, Value
*&Op2
) {
498 switch (I
->getOpcode()) {
499 case Instruction::Add
:
500 return match(V
, m_Add(m_Value(Op1
), m_Value(Op2
)));
501 case Instruction::Mul
:
502 return match(V
, m_Mul(m_Value(Op1
), m_Value(Op2
)));
504 llvm_unreachable("Unexpected instruction.");
509 const SCEV
*NaryReassociatePass::getBinarySCEV(BinaryOperator
*I
,
512 switch (I
->getOpcode()) {
513 case Instruction::Add
:
514 return SE
->getAddExpr(LHS
, RHS
);
515 case Instruction::Mul
:
516 return SE
->getMulExpr(LHS
, RHS
);
518 llvm_unreachable("Unexpected instruction.");
524 NaryReassociatePass::findClosestMatchingDominator(const SCEV
*CandidateExpr
,
525 Instruction
*Dominatee
) {
526 auto Pos
= SeenExprs
.find(CandidateExpr
);
527 if (Pos
== SeenExprs
.end())
530 auto &Candidates
= Pos
->second
;
531 // Because we process the basic blocks in pre-order of the dominator tree, a
532 // candidate that doesn't dominate the current instruction won't dominate any
533 // future instruction either. Therefore, we pop it out of the stack. This
534 // optimization makes the algorithm O(n).
535 while (!Candidates
.empty()) {
536 // Candidates stores WeakTrackingVHs, so a candidate can be nullptr if it's
539 if (Value
*Candidate
= Candidates
.back()) {
540 Instruction
*CandidateInstruction
= cast
<Instruction
>(Candidate
);
541 if (DT
->dominates(CandidateInstruction
, Dominatee
))
542 return CandidateInstruction
;
544 Candidates
.pop_back();